PyTorchのTensor配列のユニーク(一意)な要素の値、位置、数(カウント)を取得するにはtorch.uniqueを使う。
torch.uniqueの使い方
torch.unique()の第一引数にTensor配列を指定すると、ユニークな要素のTensor配列が取得できる。
第二引数のsortedのデフォルトがTrueになっているため、昇順にソートされて出力される。ソートしたくない場合は、第二引数のsortedにFalseを指定する。
import torch
a = torch.tensor([1, 3, 2, 3])
print(a)
# tensor([1, 3, 2, 3])
print(torch.unique(a))
# tensor([1, 2, 3])
print(torch.unique(a, sorted=False))
# tensor([2, 3, 1])
ユニークな要素の位置(インデックス)を取得
引数return_inverseにTrueを指定すると、ユニークな要素の位置を取得できる。
output, inverse_indices = torch.unique(a, return_inverse=True)
print(output)
# tensor([1, 2, 3])
print(inverse_indices)
# tensor([0, 2, 1, 2])
この場合は、元のTensor配列の0番目に1が、1番目に3、2番目に2、3番目に3があることを示している。
多次元配列の場合は、inverse_indicesも多次元配列で返ってくる。
b = torch.tensor([[1, 3, 2, 3],
[1, 3, 2, 3],])
output_b, inverse_indices_b = torch.unique(b, return_inverse=True)
print(output_b)
# tensor([1, 2, 3])
print(inverse_indices_b)
# tensor([[0, 2, 1, 2],
# [0, 2, 1, 2]])
ユニークな数(カウント)の取得
引数return_countsにTrueを指定すると、ユニークな数を取得できる。
output, counts = torch.unique(a, return_counts=True)
print(output)
# tensor([1, 2, 3])
print(counts)
# tensor([1, 1, 2])
この場合は、1が1つ、2が1つ、3が2つあることを示している。
関連記事、参考資料
PyTorch公式ページでも紹介されていた本で、「Tensorの仕組み」から「ディープラーニングの実践プロジェクト:肺がんの早期発見」までステップバイステップで説明しているため、中身をよく理解できます。
コメント