BASEプロダクトチームブログ

ネットショップ作成サービス「BASE ( https://thebase.in )」、ショッピングアプリ「BASE ( https://thebase.in/sp )」のプロダクトチームによるブログです。

誤分類コストを考慮した機械学習モデルの考え方

BASE Advent Calendar 2021

はじめに

この記事はBASE Advent Calendar 23日目の記事です。

こんにちは、DataStrategyチームの竹内です。
BASEではより良いサービスを提供するために色々なところで機械学習モデルが活用されています。 BASEに限らず、インターネット上のあらゆるサービスに機械学習の技術が活用されるようになって久しい昨今ですが、こうした実際のサービスやビジネス領域に近いところで活用される機械学習モデルにおいては、計算コストやメンテナンスコスト、解釈性やバイアス、データセットシフトなど色々と考えなければいけない特有の要素が存在します。

今回はその中の1つとも言える誤分類コストの非対称性の問題について考え、それに対するアプローチとしてコスト考慮型学習(Cost-Sensitive Learning)について扱っていきたいと思います。

コスト考慮型学習とは

コスト考慮型学習とは、データマイニングにおける誤分類時のコスト(誤分類コストに限らず、計算コストなどの他の要素を考慮する場合もあります。)を考慮した学習手法のことです。
機械学習などによる分類モデルを現実の問題に適用する場合、どのデータをどのクラスに誤分類してしまうかで生じるコストが異なる、いわゆる誤分類コストの非対称性の問題に直面することがあります。1
よく挙げられる例で言えば、がんのような重大な病気を診断する場合、本当に罹患している人に対して健康であると誤診してしまった際の影響は極めて致命的になり得る一方で、健康な人に対して罹患していると誤診してしまった際は追加の検査費用分のコストで済むことになります。

医療診断における非対称な誤分類コスト

健康な人 罹患している人
陰性と診断 - 治療の遅れ、信頼の失墜(コスト大)
陽性と診断 追加の検査費用(コスト小) 早期治療

医療診断の場合ではどの患者についても概ね同様なコスト行列が適用できますが、中にはサンプルごとに誤分類コストが異なる場合もあります。
例えばカードローン審査のような例では、利用客によって申請金額が異なるため誤分類コストが変わってきます。

カードローン審査におけるサンプルごとに異なる誤分類コスト

返済できる人 返済できない人
審査を通す 金利や手数料分の利益 金額分の損失
審査を通さない 適格な申請数の減少 不適格な申請数の減少

このようなサンプルごとに誤分類コストが異なるタスクにおいてモデルの作成や改良を行う際、サンプル数ベースの正答率や再現率、AUCの改善が必ずしも金額ベースの改善につながらない可能性があることに注意する必要があります。

このように現実の問題を分類タスクとして捉え、予測モデルの作成や改良を行う場合、適切な誤分類コストに基づいた性能評価が求められる場合があります。
適切な誤分類コストの設定には機械学習や統計一般の知識だけではなく、その分野特有の知識、いわゆるドメイン知識を十分に活用することが求められます。

Cost-Sensitive Learningの手法

ここからは上の例のような非対称な誤分類コストに対する具体的なアプローチについて扱っていきます。
誤分類コストを考慮したモデルを構成する手法は大きく分けて3つ存在します。

  1. 通常の学習を行なったモデルの出力に対する検出閾値を、サンプルの誤分類コストに応じて変更する
  2. 学習データセットのクラス比率、あるいは重みを誤分類コストに応じて変更した上で通常の学習を行う
  3. 誤差関数などモデルの学習手法そのものに誤分類コストを組み込む

今回は3つアプローチの中から、1つ目の閾値を変更する手法について説明していきます。

コスト行列

手法の説明に入る前に、コスト行列について整理しておきます。
クラス数Mの多クラス分類において、モデルが C_iと予測したデータの真のクラスが C_jであった時の誤分類コストを c_{ij}と表すことにします。 例えば二値分類の場合、クラス1を陽性、クラス0を陰性とすると、 c_{10}は「陽性と予測したが本当は陰性だった」ため偽陽性、逆に c_{01}は「陰性と予測したが本当は陽性だった」ため偽陰性のコストを表していることになります。

この時、 c_{ij}を要素としてもつM×Mの行列をコスト行列と呼びます。 例えば二値分類の場合は以下のような2×2の行列となります。

真のクラスが0 真のクラスが1
予測したクラスが0  c_{00}  c_{01}
予測したクラスが1  c_{10}  c_{11}

 c_{00} c_{11}には正しく分類した時の利益(負の値)が入りますが、実際はこの部分を0として誤分類コストの方に機会損失として織り込むことができます。(後ほど説明します。)

ものすごく大雑把な例ですが、先程のカードローン審査の例で申請者 iの申請金額を x_i円とし、金利を1%とした上で、rejectした場合の申請数の減少による効果を1人あたり大雑把に20000円と見積もる(返済できる人の申請数が減る場合は損失とし、返済できない人の申請数が減る場合は利益とします。)と、以下のようなコスト行列を考えることができます。

返済できる人 返済できない人
審査を通す  -0.01x_i  x_i
審査を通さない 20000 -20000

閾値の調整による誤分類コストの反映

コスト行列を定義したところで、閾値を調整する手法の具体的な説明に入っていきます。
大まかには「間違えたらまずい(誤分類コストが相対的に大きい)」クラスについては、その予測確率がたとえ50%を下回っていたとしても、予測結果としてそのクラスを出力するという手法になります。

最適な検出閾値はコスト行列から具体的に以下のように計算することができます。
データ xがクラス C_jに属する確率が P(C_j|x)であった時、コストの期待値が小さい方に分類することを考えます。
式で表すと、 (クラス C_1と予測した時のコストの期待値)  \le(クラス C_0と予測した時のコストの期待値 が成り立つ時、つまり

 
P(C_0|x)c_{10} + P(C_1|x)c_{11} 
\le P(C_0|x)c_{00} + P(C_1|x)c_{01}

が成り立つ時にクラス C_1=陽性と判別すれば良いことになります。
陽性である事後確率 P(C_1|x) pとおくと、 P(C_0|x)=1-pから

 
(1-p)c_{10}+pc_{11}\le (1-p)c_{00}+pc_{01} \\
(c_{10}-c_{00}) \le p(c_{10}-c_{00}+c_{01}-c_{11}) \\

が得られます。
ここで「間違えて分類した時のコストは常に正しく分類できた時のコストよりも大きい」つまり すべての j\neq iとなる iに対して、

 
c_{i, i} \le c_{j, i}

が成り立つことを仮定します。
この時上の式から

 
p^* = \frac{c_{10}-c_{00}}{c_{10}-c_{00}+c_{01}-c_{11}}

とすると、


p^* \le p

が得られます。
つまり正しい事後確率 pが得られた時、検出閾値を p^*として pがそれ以上であれば陽性、そうでなければ陰性とすることで誤分類コストの期待値を最小化することができます。
また p^*の式からコスト行列を

真のクラスが0 真のクラスが1
予測したクラスが0 0  c_{01} - c_{11}
予測したクラスが1  c_{10}-c_{00} 0

とおいたものも同じ結果が得られることがわかります。
これは先程説明した通り、正しく分類できた時に得られる利益を誤分類した時の機会損失として扱うことに相当します。
また、真のクラスにかかわらず同じ予測に対して常に発生する同じ大きさのコストは相殺することがわかります。

実際のデータセットを用いた例

実際のデータセットで閾値の調整によるコストの変化をみてみたいと思います。

今回はUCIのMushroom Classificationのデータセットを検証用に使わせていただくことにします。
https://www.kaggle.com/uciml/mushroom-classification
このデータセットは、様々なキノコの傘の形や色などの特徴量と共に、そのキノコが食用なのかそうでないかのラベルが与えられたデータセットとなります。
今回はこれを用いて、キノコのいくつかの特徴量からそれが食用なのかそうでないかを予測する単純な二値分類のタスクを考えます。
分類器としては単純な決定木を使うことにします。

食用でないキノコのクラスを1、食用であるキノコのクラスを0とすると、サンプルサイズは以下のようになります。

クラス0: 4208
クラス1: 3916

必要なライブラリのimportや前処理など

from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn import metrics
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt


def plot_confusion_matrix(y_test, y_pred, y_prob, cost_matrix):
    cm = metrics.confusion_matrix(y_test, y_pred)
    tn, fp, fn, tp = cm.flatten()
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * recall * precision/(recall + precision)
    fpr, tpr, thresholds = metrics.roc_curve(y_test, y_prob)
    roc_auc = metrics.auc(fpr, tpr)
    total_cost = np.sum(cost_matrix * cm.T)
    print(f"accuracy: {accuracy*100:.4f}%")
    print(f"precision: {precision*100:.4f}%")
    print(f"recall: {recall*100:.4f}%")
    print(f"f-score: {f_score:.4f}")
    print(f"AUC: {roc_auc:.4f}")
    print(f"total cost: {total_cost}")
    df_cm = pd.DataFrame(cm.T, range(2),
                      range(2))
    sns.set(font_scale=1.4)
    sns.heatmap(df_cm, annot=True ,annot_kws={"size": 16}, fmt="")
    plt.xlabel("true label")
    plt.ylabel("prediction")
    plt.show()


df = pd.read_csv("mushrooms.csv")
le = LabelEncoder()
for k, v in df.dtypes.items():
    df[k] = le.fit_transform(df[k])


# 全特徴量を使用すると完璧に分類できてしまうぐらいタスクが簡単なので、今回は実験用に使用する特徴量を制限します
df = df[df.columns[:4]]
df

混同行列と各種指標

今回の場合、食用でないキノコを誤って食用だと判別して食べてしまった時の被害と、食用のキノコを誤って食用でないと判別して食べ損なってしまった時の被害では前者の方がより重大であると考え、以下のようなコスト行列を設定することにします。

真のクラスが0 真のクラスが1
予測したクラスが0 0 10
予測したクラスが1 1 0

このコスト行列のもとで決定木による分類を行い、まずは閾値を0.5に設定して混同行列やrecision、recallなどの各種指標とともにコストの総計を計算してみます。

cost_matrix = np.array([0, 10, 1, 0]).reshape((2,2))

x_train, x_test, y_train, y_test = train_test_split(df.drop(columns=["class"]), df["class"], test_size=0.2, random_state=0)
dtc = DecisionTreeClassifier(max_depth=5, random_state=0)
model = dtc.fit(x_train, y_train)
y_prob = model.predict_proba(x_test)[:, 1]
y_pred = (y_prob >= 0.5).astype(int)
plot_confusion_matrix(y_test, y_pred, y_prob, cost_matrix)

混同行列と各種指標

コスト行列と混同行列の要素積を取ることで得られたコストの総計(total cost)は2201となりました。

次に閾値を先程の p^*に変えて同じように分類を行ってみます。学習する過程ではコスト行列は使用しないため、再学習せずに同じモデルを使用することができます。

threshold = (cost_matrix[1,0] - cost_matrix[0,0]) / (cost_matrix[1,0] - cost_matrix[0,0] + cost_matrix[0,1] - cost_matrix[1,1])
y_pred = (y_prob >= threshold).astype(int)
plot_confusion_matrix(y_test, y_pred, y_prob, cost_matrix)

混同行列と各種指標

結果としてprecisionが下がる代わりにrecallが上がることでtotal costを2201から690まで下げることができました。

一応閾値に対するtotal costをplotしてみると、 p*がtotal costを最小化する閾値であることが確認できます。

x = np.linspace(0, 0.5, 5000)
y = []
for i in x:
    y_pred = (y_prob >= i).astype(int)
    cm = metrics.confusion_matrix(y_test, y_pred)
    total_cost = np.sum(cost_matrix * cm.T)
    y.append(total_cost)

fig, ax= plt.subplots(1, 1, figsize=(10, 8))
ax.plot(x, y)
ax.vlines(threshold, ymin=0, ymax=max(y), color="orange")
ax.legend(["total cost", "p*"], loc='upper center')
ax.set_xlabel("threshold")
ax.set_ylabel("total cost")
plt.show()

検出閾値と誤分類コスト

まとめ

誤分類コストの非対称性の問題と、それに対するアプローチの1つであるコスト考慮型学習(Cost-Sensitive Learning)について紹介させていただきました。
BASEで活用されている機械学習モデルの一部にも、こうしたコスト考慮型学習のアプローチが用いられています。

実際には正確なコスト行列を設定することが難しい場合もありますが、それでも誤分類コストの非対称性については常に意識する必要があります。 例えばローン審査において機械学習モデルを人手によるチェックのための一時フィルター的な役割で使用する場合には、あらかじめ許容できる件数ベースの偽陽性率の上限を決めた上で、金額ベースの再現率が最も高くなるような閾値を設定する、といった方法が有効かもしれません。 その場合はモデルを開発するエンジニアだけではなく、二次チェックを行うオペレーターともうまくコミュニケーションを取りながら達成すべき目標を明確にしていくステップが必要不可欠となります。

また、実際に運用していく上では誤分類時のコストだけではなく、冒頭で触れた通り計算コストやメンテナンスコスト、解釈性の問題など色々な要素が影響してきます。それらを踏まえた上でドメイン知識を活用し、短期的な利益だけではなく、長期的な利益を見据えて最適なモデルを選択しチューニングしていくことが、ビジネスでの機械学習モデルの活用を求められるデータサイエンティストの役割であると考えます。

参考文献


  1. 特に現実の応用例を考える際、誤分類コストの非対称性の問題は、必ずと言っていいほど不均衡データの問題と一緒になって現れますが、個人的にはこの二つの問題は分けて考える方が良いかと思っています。というのも不均衡データには不均衡データ特有の、分類器の識別境界にかかるバイアスなど(https://scikit-learn.org/stable/auto_examples/svm/plot_separating_hyperplane_unbalanced.htmlhttps://ieeexplore.ieee.org/document/6137280 などで説明されている)タスクのドメイン(医療診断で使うのか、ローン審査で使うのかなど)とは独立した問題が存在し、誤分類コストが対称であっても生じる可能性がある一方で、誤分類コストの非対称性の問題はタスクのドメインに依存した問題であり、不均衡データでなくても起こり得るからです。ただし、不均衡データの問題を解決するために非対称な誤分類コストを設定したり、誤分類コストの非対称性を考慮するためにunder-samplingやover-samplingなどによって敢えて不均衡な事前分布を設定する(この記事では詳しく紹介できませんでしたが Cost-Sensitive Learningの手法の2つ目に相当します。)アプローチが取られることはあります。今回の記事ではトピックを絞るため、あまりデータの不均衡性の問題は取り上げていません。(今後機会があれば、2つ目以降の手法と共に記事にしようかと思います。)