スポンサーリンク

PyTorchのモデルの保存と読み込み方法

Machine Learning

PyTorchにモデルの保存と読み込みには大きく分けて2種類の方法があります。1つ目はtorch.save/torch.loadを使う方法で、2つ目はTorchScript形式で保存/読み込む方法です。

torch.saveとtorch.loadを使う方法

始めに、保存や読み込みの対象とするモデルを定義します。今回は、全結合のシンプルなニューラルネットワークを対象とします。

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

モデルのパラメータのみ保存/読み込むmodel.state_dict()

model.state_dict()を使うことでモデルのパラメータのみを保存することができます。

torch.save(model.state_dict(), 'model_weight.pth')

model.state_dict()を使って保存されたモデルはパラメータのみのため、読み込む際はモデル構造を定義した後に読み込みます。読み込みに成功すると、<All keys matched successfully>と表示されます。

model = NeuralNetwork()
model.load_state_dict(torch.load('model_weight.pth'))
# <All keys matched successfully>

モデル全体を保存/読み込む

モデル全体を保存する場合は、モデルをそのままsaveするだけです。

torch.save(model, 'model_weight.pth')

読み込みもそのままloadするだけになります。

model = torch.load('model_weight.pth')

TorchScript形式でモデル保存/読み込む方法

TorchScriptを使用してモデルを保存することもできます。

model_scripted = torch.jit.script(model) 
model_scripted.save('model_scripted.pth') 

読み込む方法もシンプルです。

model = torch.jit.load('model_scripted.pth')

関連記事、参考資料

PyTorch公式からも紹介されている本で、基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント