スポンサーリンク

Tensor配列の非ゼロ要素のインデックスを取得するtorch.nonzero【PyTorch】

Python

PyTorchで、Tensor配列の非ゼロ要素のインデックスを取得するにはtorch.nonzeroを使用する。

torch.nonzeroの使い方

torch.nonzeroの引数にTensor配列を指定すると、非ゼロ要素のインデックスが2次元配列のTensor配列として出力されます。

import torch

a = torch.tensor([1, 1, 1, 0, 1, 2])
print(torch.nonzero(a))
# tensor([[0],
#         [1],
#         [2],
#        [4],
#         [5]])
print(torch.nonzero(a).shape)
# torch.Size([5, 1])

b = torch.tensor([[1, 0, 1],
                  [0, 1, 0]])
print(torch.nonzero(b))
# tensor([[0, 0],
#         [0, 2],
#         [1, 1]])

引数as_tupleTrueに指定すると、非ゼロ要素のインデックスは、各軸ごとにまとめられタプルで返ってくる。

a = torch.tensor([1, 1, 1, 0, 1, 2])
print(torch.nonzero(a, as_tuple=True))
# (tensor([0, 1, 2, 4, 5]),)

b = torch.tensor([[1, 0, 1],
                  [0, 1, 0]])
print(torch.nonzero(b, as_tuple=True))
# (tensor([0, 0, 1]), tensor([0, 2, 1]))

c = torch.tensor([[1, 0, 0],
                  [0, 1, 0],
                  [0, 0, 1]])
print(torch.nonzero(c, as_tuple=True))
# (tensor([0, 1, 2]), tensor([0, 1, 2]))

関連記事、参考資料

PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント