スポンサーリンク

【PyTorch】Tensor配列にNaNが含まれているか判定するtorch.isnan、nanを指定した値に置き換えるtorch.nan_to_num

Python

PyTorchでTensor配列にNaNが含まれているか判定するにはtorch.isnan、NaNを指定した値に置き換えるにはtorch.nan_to_numを使う。

torch.isnanの使い方

torch.isnanにTensor配列を指定することで、配列の各要素にNaNが含まれているか判定できる。

import torch

x = torch.tensor([1, float('nan'), 2])
print(x)
# tensor([1., nan, 2.])

print(torch.isnan(x))
# tensor([False,  True, False])

ちなみに、nan同士を比較してもFalseとなるため判定できない。

print(float('nan') == float('nan'))
# False

torch.nan_to_numの使い方

torch.nan_to_numを使うことで、NaNを指定した値に置き換えることができる。またNaN以外にもinf(正の無限大)と-inf(負の無限大)も置き換えることができる。

torch.nan_to_numにTensor配列を指定することで、NaN、infと-infが置き換わる。デフォルトでは、NaNは0、infはdtypeで表現できる最大値、-infはdtypeで表現できる最小値になる。

x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.])

print(torch.nan_to_num(x))
# tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  3.0000e+00])

引数nanposinfneginfを指定することで、任意の値に置き換えることができる。

print(torch.nan_to_num(x, nan=5, posinf=10, neginf=-10))
# tensor([  5.,  10., -10.,   3.])

関連記事、参考資料

PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント