はじめに
こんにちは、Data Strategy所属の岡です。グループ会社BASE BANKで分析/モデリングなども兼務しています。
テキストデータを特徴量にもつ不均衡データ分類問題をDNNで解きたくなった際、下記の論文を参考にしたのでその内容を紹介します。
https://users.cs.fiu.edu/~chens/PDF/ISM15.pdf
不均衡データ分類問題ってなに?
何かしらのカテゴリを機械学習などで分類予測しようとする際、カテゴリごとのデータ件数に偏りがある、特に正例のデータが極端に少ないケースで予測精度が上がりにくい、という問題をこのように呼んでいます。
例: 不正決済と正常な注文、不正商品と健全な商品、がん患者と正常な患者
普通はどうやって対処するの?
ベースとなるアプローチは下記3つにまとめられます。
アプローチ | 内容 | デメリット |
---|---|---|
アンダーサンプリング | 多数派データをランダムに減らして少数派データと均一にする | 多数派データの多くを捨てるため、情報損失が生じうる |
オーバーサンプリング | 少数派を増やし、多数派データと均一にする 例: SMOTE, ADASYN |
少数派データの水増しになるため、過学習が懸念される |
損失関数のカスタマイズ | 損失関数に対して多数派データのコストを少数派データとの割合に応じて割り引くなど | あまりデメリットについて明言されてるケースは少ないと思いますが、これも過学習になりうるという印象 |
個人的には、アンダーサンプリングの情報損失という弱点をカバーしている under sampling + bagging をよく使用しています。DNNでこれと似たようなアプローチができないか調べていたところ、冒頭の論文を発見しました。
ミニバッチでunder samplingするアプローチ
今回紹介する論文では、ミニバッチ作成時にラベルごとのサンプルサイズを合わせる方法を提案しています。
この処理フローは下図のようにまとめられます。
まず、多数派データ(P個)のみをバッチサイズ(ここではN個)分に分割してミニバッチ作成します。ここでミニバッチ1つあたりのサイズは(P / N)個です。
続いて少数派データ(Q個)から(P / N)個を重複がないようにランダムサンプリングし、最初に作った1個目のミニバッチのデータに混ぜます。このサンプリングを続いて2,3,4...N個目のミニバッチごとに繰り返し、ラベルの比率が均等なN個のミニバッチが作られます。
注意点として、ミニバッチ1個あたりに対して少数派データ(Q個)から(P / N)個を選び出すときは非復元抽出となるため、ミニバッチ単体で見ればデータの重複はありません。ただし、別のミニバッチを作るときには同じようにQ個から(P / N)個選び出すことになるので、バッチ全体としては少数派データの重複を許しつつサンプリングされています。
実装例
PyTorchでの実装例を示します。図示したミニバッチ作成部分だけ簡略化したコードだと下記のようになります。
class BinaryBalancedSampler: def __init__(self, features, labels, n_samples): self.features = features self.labels = labels label_counts = np.bincount(labels) major_label = label_counts.argmax() minor_label = label_counts.argmin() self.major_indices = np.where(labels == major_label)[0] self.minor_indices = np.where(labels == minor_label)[0] np.random.shuffle(self.major_indices) np.random.shuffle(self.minor_indices) self.used_indices = 0 self.count = 0 self.n_samples = n_samples self.batch_size = self.n_samples * 2 def __iter__(self): self.count = 0 while self.count + self.batch_size < len(self.major_indices): # 多数派データ(major_indices)からは順番に選び出し # 少数派データ(minor_indices)からはランダムに選び出す操作を繰り返す indices = self.major_indices[self.used_indices:self.used_indices + self.n_samples].tolist()\ + np.random.choice(self.minor_indices, self.n_samples, replace=False).tolist() yield torch.tensor(self.features[indices]), torch.tensor(self.labels[indices]) self.used_indices += self.n_samples self.count += self.n_samples * 2
コメントでも触れられている↓の箇所のコードが今回の肝となる操作です。
indices = self.major_indices[self.used_indices:self.used_indices + self.n_samples].tolist()\
+ np.random.choice(self.minor_indices, self.n_samples, replace=False).tolist()
多数派データのインデックス(major_indices)からは順番に選び出し、少数派データのインデックス(minor_indices)からはランダムに選び出す操作をしています。それぞれのラベルからn_samples
ずつ取り出して結合しています。
この操作は後でイテレータとして用いるので、def __iter__(self)
で定義しています。yield
の出力はたとえば下記のようになります。
test_iter = BinaryBalancedSampler(features=train_features, labels=train_labels, n_samples=50) for i in test_iter: feature, label = i print(feature.shape) print(label.shape) print(np.bincount(label)) break > torch.Size([100, 29]) > torch.Size([100]) > [50 50]
1イテレーションあたりに特徴量データと対応するラベルを返しています。ここではn_samples=50
としたので2ラベル分で100個のデータが生成されています。ラベルの分布も [50 50]
と均一になっているので、想定通りのミニバッチが作成できていそうです。
実験
不均衡データを普通に学習させた時と、均一になったミニバッチで学習させた場合の精度の差を見てみます。
学習結果の全容はこちらのgitに載せたので、ここでは簡単に説明していきます。
前処理
サンプルデータセットにはtensorflowの下記チュートリアルと同じものを用いました。
https://www.tensorflow.org/tutorials/structured_data/imbalanced_data
こちらも不均衡データをどのように分類するかという内容になっています。
import pandas as pd from IPython.display import display raw_df = pd.read_csv('https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv') display(raw_df.head())
読み込んだデータは上記のような中身になっていて、Class
列が分類対象のラベルです。
このラベルの分布が偏っていることを確認します。
import numpy as np neg, pos = np.bincount(raw_df['Class']) total = neg + pos print('Examples:\n Total: {}\n Positive: {} ({:.2f}% of total)\n'.format( total, pos, 100 * pos / total)) > Examples: > Total: 284807 > Positive: 492 (0.17% of total)
1のラベルが全体の0.17%しかないという偏り具合で、かなりの不均衡データになっています。
続いてチュートリアルと同じように前処理し、訓練データとテストデータに分割します。
from sklearn.model_selection import train_test_split # Cleaning cleaned_df = raw_df.copy() # `Time` カラムは不要 cleaned_df.pop('Time') # `Amount` カラムは数値のレンジが広すぎるので対数化 eps = 0.001 cleaned_df['LogAmount'] = np.log(cleaned_df.pop('Amount') + eps) # split train and test. train_df, test_df = train_test_split(cleaned_df, test_size=0.2, random_state=0) # split label and feature train_labels = np.array(train_df.pop('Class')) test_labels = np.array(test_df.pop('Class')) train_features = np.array(train_df) test_features = np.array(test_df)
使用するネットワークはかなりシンプルにしました。
import torch import torch.nn as nn import torch.nn.functional as F class SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() self.fc1 = nn.Linear(29, 32) self.fc2 = nn.Linear(32, 16) self.fc3 = nn.Linear(16, 1) def forward(self, x): x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = torch.sigmoid(self.fc3(x)) return x
入力される特徴量が29ユニット、最終的に予測するラベルは0 or 1なので出力は1ユニットにしています。
その他、損失関数とオプティマイザーは下記を用いました。
import torch.optim as optim criterion = nn.BCELoss() # BCELoss: Binary Crossentropy optimizer = optim.Adam(simple_net.parameters(), lr=1e-3)
普通に学習させた結果
学習部分のコードは下記のようになっています。
PyTorch特有のDataset, DataLoader
を使ってて見慣れない部分があるかもしれませんが、単純にミニバッチサイズは1000、エポック数は2で訓練させただけの普段通りの処理です。読み飛ばして構いません。
from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self, feature, label): self.feature = feature self.label = label self.data_num = feature.shape[0] def __len__(self): return self.data_num def __getitem__(self, idx): out_feature = self.feature[idx] out_label = self.label[idx] return out_feature, out_label BATCH_SIZE = 1000 trainset = MyDataset(feature=train_features, label=train_labels) trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True) testset = MyDataset(feature=test_features, label=test_labels) testloader = DataLoader(testset, shuffle=False) # training for epoch in range(2): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = simple_net(inputs.float()) loss = criterion(outputs, labels.view(-1, 1).float()) loss.backward() optimizer.step() # print statistics running_loss += loss.item() if i % 100 == 99: # print every 100 mini-batches print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100)) running_loss = 0.0 print('Finished Training')
次に学習し終えたネットワークでテストデータのラベルを予想させます。
pred = [] Y = [] for x, y in testloader: with torch.no_grad(): output = simple_net(x.float()) pred += [1 if output > 0.5 else 0] Y += [int(l) for l in y]
テストデータの予測結果をconfusion matrixで可視化すると下記のようになりました。
全データをラベル0(多数派)と予測していて、うまい具合に不均衡データの罠にハマっています。
ミニバッチでunder samplingして予測した結果
実装例のコードで書いたBinaryBalancedSampler
を用いて学習させます。
少数派ラベルのサイズを考慮して、ここではミニバッチサイズは100、エポック数は2にしています。
balanced_net = SimpleNet() balanced_loader = BinaryBalancedSampler(features=train_features, labels=train_labels, n_samples=50) criterion = nn.BCELoss() # BCELoss: Binary Crossentropy optimizer = optim.Adam(balanced_net.parameters(), lr=1e-3) # training for epoch in range(2): running_loss = 0.0 for i, data in enumerate(balanced_loader, 0): inputs, labels = data # zero the parameter gradients optimizer.zero_grad() # forward + backward + optimize outputs = balanced_net(inputs.float()) loss = criterion(outputs, labels.view(-1, 1).float()) loss.backward() optimizer.step() # print statistics running_loss += loss.item() if i % 1000 == 999: # print every 1000 mini-batches print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 1000)) running_loss = 0.0 print('Finished Training')
このネットワークでテストデータを予測させると下記のようになりました。
多数派ラベルの巻き込みも少なく(187 サンプル)、少数派ラベルの9割弱(88サンプル)を当てることができました。
実験のまとめ
サンプルに使った不均衡ラベルのデータセットについては、ミニバッチでunder samplingするアプローチは有効そうです。
感想
この方法はミニバッチごとには均一なデータセットが得られますが、バッチ全体で見れば少数派データを複数回ピックアップするので、ある意味オーバーサンプリングとなり過学習の心配があります。early stoppingなどを組み合わせる必要がありそうです。
(あと発表が2015年とやや古い...)
とはいえ実装はとても簡単なため、実務でサクッと試せる点が個人的に気に入っています。