PyTorchで自作の損失関数(ロス関数)を使う方法について見ていきます。
自作損失関数の定義
自作損失関数はnn.Moduleを継承して、基本的には定義します。
import torch
from torch import nn
class CustomLoss(nn.Module):
def __init__(self): # パラメータの設定など初期化処理を行う
super(CustomLoss, self).__init__()
def forward(self, outputs, targets): # モデルの出力と正解データ
# ロスの計算を何かしら書く
# loss = torch.mean(outputs - targets)
# ロスの計算を返す
return loss
平均絶対誤差(MAE:Mean Absolute Error)を定義する場合は、以下のように書く。
class MAELoss(nn.Module):
def __init__(self):
super(MAELoss, self).__init__()
def forward(self, outputs, targets):
loss = torch.mean(torch.abs(outputs - targets))
return loss
損失関数を使う
PyTorch標準の損失関数と同様に、以下のようにして使うことができます。
criterion = MAELoss()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
コメント