PyTorchで、Tensor配列の非ゼロ要素のインデックスを取得するにはtorch.nonzeroを使用する。
torch.nonzeroの使い方
torch.nonzeroの引数にTensor配列を指定すると、非ゼロ要素のインデックスが2次元配列のTensor配列として出力されます。
import torch
a = torch.tensor([1, 1, 1, 0, 1, 2])
print(torch.nonzero(a))
# tensor([[0],
# [1],
# [2],
# [4],
# [5]])
print(torch.nonzero(a).shape)
# torch.Size([5, 1])
b = torch.tensor([[1, 0, 1],
[0, 1, 0]])
print(torch.nonzero(b))
# tensor([[0, 0],
# [0, 2],
# [1, 1]])
引数as_tupleをTrueに指定すると、非ゼロ要素のインデックスは、各軸ごとにまとめられタプルで返ってくる。
a = torch.tensor([1, 1, 1, 0, 1, 2])
print(torch.nonzero(a, as_tuple=True))
# (tensor([0, 1, 2, 4, 5]),)
b = torch.tensor([[1, 0, 1],
[0, 1, 0]])
print(torch.nonzero(b, as_tuple=True))
# (tensor([0, 0, 1]), tensor([0, 2, 1]))
c = torch.tensor([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
print(torch.nonzero(c, as_tuple=True))
# (tensor([0, 1, 2]), tensor([0, 1, 2]))
関連記事、参考資料
PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。
コメント