PyTorchのTensor配列に対して、条件を満たす要素に対して処理を行うには、torch.whereを使う。
torch.whereの使い方
torch.where(condition, x, y)
torch.whereのconditionに条件式を指定し、その条件がTrueであればx、Falseであればyが代入される。 例えば、Tensor配列の要素の値が10以上であれば100、10未満であれば0にしたい場合は以下のように書く。
import torch
x = torch.Tensor([[0,10,2,30,4,],
[5,60,7,80,9]])
print(x)
# tensor([[ 0., 10., 2., 30., 4.],
# [ 5., 60., 7., 80., 9.]])
y = torch.where(x >= 10, 100, 0)
print(y)
# tensor([[ 0, 100, 0, 100, 0],
# [ 0, 100, 0, 100, 0]])
元のTensor配列の条件を満たす要素のみ置換したい場合は以下のように書く。
y = torch.where(x >= 10, 100, x)
print(y)
# Tensor([[ 0., 100., 2., 100., 4.],
# [ 5., 100., 7., 100., 9.]])
条件を満たす要素のインデックスを取得
torch.whereのconditionのみ指定した場合は、条件を満たす要素のインデックスを取得できる。
print(torch.where(x >= 10))
# (tensor([0, 0, 1, 1]), tensor([1, 3, 1, 3]))
関連記事、参考資料
PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。
コメント