PyTorchのモデルmodelAをmodelBにコピーすると、変数名は異なりますが、idが同じことからメモリが共有されていることが分かります。この場合、どちらか一方のモデルを学習などで更新すると、もう一方のモデルも変更されてしまい、学習中にエポックごとのモデルを保存したい場合やベストなモデルを保存したい場合に意図せず変更されてしまう可能性があります。
ここでは、このような問題が発生しないようなモデルのコピーの仕方について記載します。
import torch
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
modelA = NeuralNetwork()
modelB = modelA
print(id(modelA) == id(modelB))
# True
copy.deepcopyの使い方
Pythonのcopyモジュールのdeepcopyを使います。deepcopyの引数にコピーしたい変数を指定します。modelAとmodelBのidが違うことからメモリが異なることが分かります。
import copy
modelA = NeuralNetwork()
modelB = copy.deepcopy(modelA)
print(id(modelA) == id(modelB))
# False
関連記事、参考記事
最短コースでわかる PyTorch &深層学習プログラミング
PyTorchのモデルの扱いについて丁寧に書かれておりオススメです。
コメント