スポンサーリンク

【PyTorch】Tensor配列のユニークな要素の値、位置、数を取得するtorch.unique

Python

PyTorchのTensor配列のユニーク(一意)な要素の値、位置、数(カウント)を取得するにはtorch.uniqueを使う。

torch.uniqueの使い方

torch.unique()の第一引数にTensor配列を指定すると、ユニークな要素のTensor配列が取得できる。

第二引数のsortedのデフォルトがTrueになっているため、昇順にソートされて出力される。ソートしたくない場合は、第二引数のsortedにFalseを指定する。

import torch

a = torch.tensor([1, 3, 2, 3])
print(a)
# tensor([1, 3, 2, 3])

print(torch.unique(a))
# tensor([1, 2, 3])

print(torch.unique(a, sorted=False))
# tensor([2, 3, 1])

ユニークな要素の位置(インデックス)を取得

引数return_inverseにTrueを指定すると、ユニークな要素の位置を取得できる。

output, inverse_indices = torch.unique(a, return_inverse=True)

print(output)
# tensor([1, 2, 3])
print(inverse_indices)
# tensor([0, 2, 1, 2])

この場合は、元のTensor配列の0番目に1が、1番目に3、2番目に2、3番目に3があることを示している。

多次元配列の場合は、inverse_indicesも多次元配列で返ってくる。

b = torch.tensor([[1, 3, 2, 3],
                  [1, 3, 2, 3],])

output_b, inverse_indices_b = torch.unique(b, return_inverse=True)

print(output_b)
# tensor([1, 2, 3])
print(inverse_indices_b)
# tensor([[0, 2, 1, 2],
#        [0, 2, 1, 2]])

ユニークな数(カウント)の取得

引数return_countsにTrueを指定すると、ユニークな数を取得できる。

output, counts  = torch.unique(a, return_counts=True)

print(output)
# tensor([1, 2, 3])
print(counts)
# tensor([1, 1, 2])

この場合は、1が1つ、2が1つ、3が2つあることを示している。

関連記事、参考資料

PyTorch公式ページでも紹介されていた本で、「Tensorの仕組み」から「ディープラーニングの実践プロジェクト:肺がんの早期発見」までステップバイステップで説明しているため、中身をよく理解できます。

PyTorch実践入門 (Compass Booksシリーズ)

コメント