PyTorchで自作の損失関数(loss function)を使う

Machine Learning

PyTorchで自作の損失関数(ロス関数)を使う方法について見ていきます。

スポンサーリンク

自作損失関数の定義

自作損失関数はnn.Moduleを継承して、基本的には定義します。

import torch
from torch import nn
class CustomLoss(nn.Module):
    def __init__(self): # パラメータの設定など初期化処理を行う
        super(CustomLoss, self).__init__()

    def forward(self, outputs, targets): # モデルの出力と正解データ

        # ロスの計算を何かしら書く
        # loss = torch.mean(outputs - targets)
        # ロスの計算を返す
        return loss

平均絶対誤差(MAE:Mean Absolute Error)を定義する場合は、以下のように書く。

class MAELoss(nn.Module):
    def __init__(self): 
        super(MAELoss, self).__init__()

    def forward(self, outputs, targets):

        loss = torch.mean(torch.abs(outputs - targets))
        return loss
スポンサーリンク

損失関数を使う

PyTorch標準の損失関数と同様に、以下のようにして使うことができます。

criterion = MAELoss()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()

関連記事、参考資料

コメント