PyTorchで学習を行なっていると、以下エラーが出たので対処方法をメモしていく。
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])
原因
学習するModelにバッチノーマライゼーション(torch.nn.BatchNorm2dなど)を含んでいると、バッチ数が1(入力データの数が1)の場合平均、分散が計算できないため発生する。
これは、バッチサイズ20などに指定した場合も、入力データの総数によっては発生しえる。例えば、入力データが201の場合、最後のバッチは入力データが1になってしまうためである。
解決方法
入力データを調整することで解決することもできるが、torch.utils.data.DataLoaderのオプションを指定することで自動で調整可能となる。
drop_lastをTrueに指定することで、データセットサイズがバッチサイズで割り切れない場合に最後のバッチが削除されます。
DataLoader(ds, batch_size, drop_last=True)
関連記事、参考資料
PyTorchの入門書で、GPUの利用方法、ネットワークの構築方法や転移学習まで幅広く書かれていてオススメです。torch.utils.data.DataLoaderやDatasetの使い方も詳しく書かれています。
コメント