スポンサーリンク

PyTorch modelを安全にコピーする方法【copy.deepcopy】

Machine Learning

PyTorchのモデルmodelAをmodelBにコピーすると、変数名は異なりますが、idが同じことからメモリが共有されていることが分かります。この場合、どちらか一方のモデルを学習などで更新すると、もう一方のモデルも変更されてしまい、学習中にエポックごとのモデルを保存したい場合やベストなモデルを保存したい場合に意図せず変更されてしまう可能性があります。
ここでは、このような問題が発生しないようなモデルのコピーの仕方について記載します。

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

modelA = NeuralNetwork()

modelB = modelA

print(id(modelA) == id(modelB))
# True

copy.deepcopyの使い方

Pythonのcopyモジュールのdeepcopyを使います。deepcopyの引数にコピーしたい変数を指定します。modelAとmodelBのidが違うことからメモリが異なることが分かります。

import copy

modelA = NeuralNetwork()

modelB = copy.deepcopy(modelA)

print(id(modelA) == id(modelB))
# False

関連記事、参考記事

最短コースでわかる PyTorch &深層学習プログラミング
PyTorchのモデルの扱いについて丁寧に書かれておりオススメです。

コメント