PyTorchにモデルの保存と読み込みには大きく分けて2種類の方法があります。1つ目はtorch.save/torch.loadを使う方法で、2つ目はTorchScript形式で保存/読み込む方法です。
torch.saveとtorch.loadを使う方法
始めに、保存や読み込みの対象とするモデルを定義します。今回は、全結合のシンプルなニューラルネットワークを対象とします。
import torch
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
モデルのパラメータのみ保存/読み込むmodel.state_dict()
model.state_dict()を使うことでモデルのパラメータのみを保存することができます。
torch.save(model.state_dict(), 'model_weight.pth')
model.state_dict()を使って保存されたモデルはパラメータのみのため、読み込む際はモデル構造を定義した後に読み込みます。読み込みに成功すると、<All keys matched successfully>と表示されます。
model = NeuralNetwork()
model.load_state_dict(torch.load('model_weight.pth'))
# <All keys matched successfully>
モデル全体を保存/読み込む
モデル全体を保存する場合は、モデルをそのままsaveするだけです。
torch.save(model, 'model_weight.pth')
読み込みもそのままloadするだけになります。
model = torch.load('model_weight.pth')
TorchScript形式でモデル保存/読み込む方法
TorchScriptを使用してモデルを保存することもできます。
model_scripted = torch.jit.script(model)
model_scripted.save('model_scripted.pth')
読み込む方法もシンプルです。
model = torch.jit.load('model_scripted.pth')
関連記事、参考資料
PyTorch公式からも紹介されている本で、基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。
コメント