PyTorchでTensor配列にNaNが含まれているか判定するにはtorch.isnan、NaNを指定した値に置き換えるにはtorch.nan_to_numを使う。
torch.isnanの使い方
torch.isnanにTensor配列を指定することで、配列の各要素にNaNが含まれているか判定できる。
import torch
x = torch.tensor([1, float('nan'), 2])
print(x)
# tensor([1., nan, 2.])
print(torch.isnan(x))
# tensor([False, True, False])
ちなみに、nan同士を比較してもFalseとなるため判定できない。
print(float('nan') == float('nan'))
# False
torch.nan_to_numの使い方
torch.nan_to_numを使うことで、NaNを指定した値に置き換えることができる。またNaN以外にもinf(正の無限大)と-inf(負の無限大)も置き換えることができる。
torch.nan_to_numにTensor配列を指定することで、NaN、infと-infが置き換わる。デフォルトでは、NaNは0、infはdtypeで表現できる最大値、-infはdtypeで表現できる最小値になる。
x = torch.tensor([float('nan'), float('inf'), -float('inf'), 3.])
print(torch.nan_to_num(x))
# tensor([ 0.0000e+00, 3.4028e+38, -3.4028e+38, 3.0000e+00])
引数nan、posinf、neginfを指定することで、任意の値に置き換えることができる。
print(torch.nan_to_num(x, nan=5, posinf=10, neginf=-10))
# tensor([ 5., 10., -10., 3.])
関連記事、参考資料
PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。
コメント