スポンサーリンク

【PyTorch】Tensor配列の最大値・最小値のインデックスを取得するtorch.argmax、torch.argmin【位置】

Python

PyTorchでTensor配列の最大値・最小値となる要素のインデックスを取得するには、torch.argmaxtorch.argminを使う。

torch.argmaxの使い方

torch.argmax()の引数にTenasor配列を指定すると最大値のインデックスが返ってくる。多次元配列の場合は、平坦化(torch.flatten())された配列のインデックスが返ってくる。

import torch

a = torch.tensor([1, 3, 2])
print(torch.argmax(a))
# tensor(1)

a = torch.tensor([[0, 3, 8],
                 [6, 1, 4],
                 [5, 7, 2]])

print(torch.argmax(a))
# tensor(2)
print(torch.flatten(a))
# tensor([0, 3, 8, 6, 1, 4, 5, 7, 2])

第二引数dimを指定すると、各軸に沿って最大値のインデックスが返ってくる。

print(torch.argmax(a, dim=0))
# tensor([1, 2, 0])
print(torch.argmax(a, dim=1))
# tensor([2, 0, 1])

torch.maxでの取得方法

torch.max()でも、引数dimを指定している場合は、最大値のインデックスがindicesとして返ってくる。以下の場合、indicesを取り出すとtorch.argmax()と同様な結果になることが分かる。

print(torch.max(a,dim=0))
# torch.return_types.max(values=tensor([6, 7, 8]),indices=tensor([1, 2, 0]))

print(torch.max(a,dim=0).indices)
# tensor([1, 2, 0])

torch.argminの使い方

torch.argmin()の引数にTenasor配列を指定すると最小値のインデックスが返ってくる。多次元配列の場合は、平坦化(torch.flatten())された配列のインデックスが返ってくる。

print(torch.argmin(a))
# tensor(0)

print(torch.flatten(a))
# tensor([0, 3, 8, 6, 1, 4, 5, 7, 2])

第二引数dimを指定すると、各軸に沿って最小値のインデックスが返ってくる。

print(torch.argmin(a, dim=0))
# tensor([0, 1, 2])

print(torch.argmin(a, dim=1))
# tensor([0, 1, 2])

torch.minでの取得方法

torch.min()でも、引数dimを指定している場合は、最小値のインデックスがindicesとして返ってくる。以下の場合、indicesを取り出すとtorch.argmin()と同様な結果になることが分かる。

print(torch.min(a,dim=0))
# torch.return_types.min(values=tensor([0, 1, 2]),indices=tensor([0, 1, 2]))

print(torch.min(a,dim=0).indices)
# tensor([1, 2, 0])

関連記事、参考資料

コメント