スポンサーリンク

【PyTorch】条件に応じた処理を行うtorch.where

Python

PyTorchのTensor配列に対して、条件を満たす要素に対して処理を行うには、torch.whereを使う。

torch.whereの使い方

torch.where(condition, x, y)

torch.whereconditionに条件式を指定し、その条件がTrueであればx、Falseであればyが代入される。 例えば、Tensor配列の要素の値が10以上であれば100、10未満であれば0にしたい場合は以下のように書く。

import torch

x = torch.Tensor([[0,10,2,30,4,],
                  [5,60,7,80,9]])
print(x)
# tensor([[ 0., 10.,  2., 30.,  4.],
#        [ 5., 60.,  7., 80.,  9.]])

y = torch.where(x >= 10, 100, 0)
print(y)
# tensor([[  0, 100,   0, 100,   0],
#        [  0, 100,   0, 100,   0]])

元のTensor配列の条件を満たす要素のみ置換したい場合は以下のように書く。

y = torch.where(x >= 10, 100, x)
print(y)
# Tensor([[  0., 100.,   2., 100.,   4.],
#        [  5., 100.,   7., 100.,   9.]])

条件を満たす要素のインデックスを取得

torch.whereconditionのみ指定した場合は、条件を満たす要素のインデックスを取得できる。

print(torch.where(x >= 10))
# (tensor([0, 0, 1, 1]), tensor([1, 3, 1, 3]))

関連記事、参考資料

PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント