スポンサーリンク

【PyTorch】Focal Lossの実装と効果の確認【自作損失関数】

Machine Learning

PyTorchにおけるFocal Lossの実装を行ない、簡単な追試を行ない性能がどのようになるか見ていきます。

Focal Lossについて

Facebook AI Research (FAIR)によって2017年に物体検出を対象に提案された損失関数です。

「物体検出におけるR-CNNなどの2段階手法に比べて、1段階手法は高速な一方で性能が劣る課題があった。この性能が劣る理由は、クラス間の不均衡であることを発見し、これを解決するためにFocal lossを提案した。この損失関数を組み込んだネットワークを提案し、既存の2段階検出器の性能を超えつつ、一段階検出器と同等の速度を達成した。」

論文の概要は上記のような内容で、ここでFocal Lossが使われています。Focal Lossは、分類が容易なサンプルの重みを下げることで、分類が難しいサンプルにより焦点をあてる。これにより、サンプル数が少ないクラスや分類が難しいサンプルに対して学習がしやすくなる特徴があります。

Focal Lossは物体検出を対象に提案されたロス関数ですが、シンプルな形であり様々な分野やタスクに応用可能です。

マルチラベルを対象としたFocal Loss

マルチラベル分類タスクを対象としたFocal Lossを実装していきます。

import torch
from torch import nn

class Focal_MultiLabel_Loss(nn.Module):
    def __init__(self, gamma):
      super(Focal_MultiLabel_Loss, self).__init__()
      self.gamma = gamma
      self.bceloss = nn.BCELoss(reduction='none')

    def forward(self, outputs, targets): 
      bce = self.bceloss(outputs, targets)
      bce_exp = torch.exp(-bce)
      focal_loss = (1-bce_exp)**self.gamma * bce
      return focal_loss.mean()

実装が正しいか確認するために、Binary Cross Entropy(BCE)と比較します。Focal Lossのγは2に指定します。

import matplotlib.pyplot as plt

# binary cross entropy
loss = nn.BCELoss(reduction='none')
target = torch.Tensor([1])
for input in torch.arange(0.001, 1, 0.001):
  output = loss(input, target[0])
  plt.scatter(input, output, color = "blue")

# focal loss γ=2
loss = Focal_MultiLabel_Loss(gamma=2)
for input in torch.arange(0.001, 1, 0.001):
  output = loss(input, target[0])
  plt.scatter(input, output, color = "red")

plt.xlabel('probability of ground truth class')
plt.ylabel('loss');

青線がBCEの結果で、赤線がFocal lossの結果になります。論文の図と同様になっていることが確認できました。

Focal Lossの効果の確認

Focal Lossの効果の確認するために、Kaggleクレジットカード詐欺データセットをダウンロードを利用します。

import pandas as pd

raw_df = pd.read_csv('<https://storage.googleapis.com/download.tensorflow.org/data/creditcard.csv>')
display(raw_df.head())

データの不均衡を確認してみると、全体の14%とPositiveが非常に少ないことが分かります。

import numpy as np
neg, pos = np.bincount(raw_df['Class'])
total = neg + pos
print('Examples:\\n    Total: {}\\n    Positive: {} ({:.2f}% of total)\\n'.format(
    total, pos, 100 * pos / total))

# Examples:
#     Total: 284807
#     Positive: 492 (0.17% of total)

簡単な前処理及び訓練データと検証データに分割し、DataLoaderにセットします。

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
cleaned_df = raw_df.copy()

cleaned_df.pop('Time')

eps = 0.001
cleaned_df['LogAmount'] = np.log(cleaned_df.pop('Amount') + eps)

train_df, test_df = train_test_split(cleaned_df, test_size=0.2, random_state=0)

train_labels = np.array(train_df.pop('Class'))
test_labels = np.array(test_df.pop('Class'))

train_features = np.array(train_df)
test_features = np.array(test_df)

batch_size = 128

train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_features).to(torch.float32), torch.tensor(train_labels.astype(np.float32)).unsqueeze(1))
val_dataset = torch.utils.data.TensorDataset(torch.tensor(test_features).to(torch.float32), torch.tensor(test_labels.astype(np.float32)).unsqueeze(1))
train_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, shuffle=False)

シンプルなニューラルネットワークを定義します。

import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(29, 16)
        self.fc2 = nn.Linear(16, 8)
        self.fc3 = nn.Linear(8, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

まずは、損失関数をBinary Cross Entropyにして、学習させ性能を見ていきます。

import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

model = SimpleNet().to(device)

criterion = nn.BCELoss() # Binary Cross Entropy
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        # print(X)
        # print(pred)
        # print(y)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

epochs = 2
for t in range(epochs):
    print(f"Epoch {t+1}\\\\n-------------------------------")
    train(train_dataloader, model, criterion, optimizer)
print("Done!")

学習したモデルで検証データを推論し、混同行列を見てみます。

from sklearn.metrics import confusion_matrix
import seaborn as sns

pred = []
Y = []
model.eval()
for x, y in val_dataloader:
    with torch.no_grad():
        output = model(x.to(device))
    pred += [1 if output > 0.5 else 0]
    Y += [int(l) for l in y]

# 混同行列表示用
def plot_confusion_matrix(true, predicted):
    cm = confusion_matrix(true, predicted)
    fig, ax = plt.subplots(figsize = (10,10))
    sns.heatmap(cm, square=True, cbar=False, annot=True, cmap='Blues')
    plt.xlabel("Predicted label")
    plt.ylabel("True label")

plot_confusion_matrix(Y, pred)

全てのデータをモデルが0と予測していることが分かります。

次に、損失関数にFocal Lossを設定したモデルで検証データを推論し、混同行列を見てみます。

model = SimpleNet().to(device)

criterion = Focal_MultiLabel_Loss(gamma=4) # Binary Cross Entropy
optimizer = optim.Adam(model.parameters(), lr=1e-3)

for t in range(epochs):
    print(f"Epoch {t+1}\\\\n-------------------------------")
    train(train_dataloader, model, criterion, optimizer)
print("Done!")

pred = []
Y = []
model.eval()
for x, y in val_dataloader:
    with torch.no_grad():
        output = model(x.to(device))
    pred += [1 if output > 0.5 else 0]
    Y += [int(l) for l in y]

plot_confusion_matrix(Y, pred)

少数派のラベルに対しても推論が上手くいっていることが確認できました。

関連記事、参考資料

コメント