スポンサーリンク

完全固定!PyTorchの乱数シード

PyTorch

機械学習では、再現性を持たせるために乱数シード(seed)を固定することが多いと思います。しかし、PyTorchを使っていると、固定しているつもりでも、なかなか出来ていない事が多いです。これは、PyTorchに関連するライブラリが多い事と、GPUのdeterministicを有効化する必要があるからです。この記事では、その固定方法についてのコードを紹介します。

PyTorchの乱数シード固定方法

PyTorch OfficialのREPRODUCIBILITYのページを参考に、Python、Numpy、PyTorchそれぞれの乱数ジェネレータを固定しないといけないことが分かります。また、nondeterministic algorithmsをしないように設定する必要があります。それらのコードを固定したコードは以下になります。

import random
import numpy as np
import torch

def torch_fix_seed(seed=42):
    # Python random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.use_deterministic_algorithms = True


torch_fix_seed()

GPU(deterministic)を有効にすると、再現性はあがるが、処理パフォーマンスが低下するので注意が必要のようです。

関連記事、参考資料

短コースでわかる PyTorch &深層学習プログラミング
図を豊富に使っており、PyTorchの全体像を掴むのに最適な本です。この本を読んだ後だと、PyTorchの公式ドキュメントがかなり読みやすくなりました。乱数の固定方法などについても書かれている良書です。

コメント