PyTorchでTensor配列の最大値・最小値となる要素のインデックスを取得するには、torch.argmax、torch.argminを使う。
torch.argmaxの使い方
torch.argmax()の引数にTenasor配列を指定すると最大値のインデックスが返ってくる。多次元配列の場合は、平坦化(torch.flatten())された配列のインデックスが返ってくる。
import torch
a = torch.tensor([1, 3, 2])
print(torch.argmax(a))
# tensor(1)
a = torch.tensor([[0, 3, 8],
[6, 1, 4],
[5, 7, 2]])
print(torch.argmax(a))
# tensor(2)
print(torch.flatten(a))
# tensor([0, 3, 8, 6, 1, 4, 5, 7, 2])
第二引数dimを指定すると、各軸に沿って最大値のインデックスが返ってくる。
print(torch.argmax(a, dim=0))
# tensor([1, 2, 0])
print(torch.argmax(a, dim=1))
# tensor([2, 0, 1])
torch.maxでの取得方法
torch.max()でも、引数dimを指定している場合は、最大値のインデックスがindicesとして返ってくる。以下の場合、indicesを取り出すとtorch.argmax()と同様な結果になることが分かる。
print(torch.max(a,dim=0))
# torch.return_types.max(values=tensor([6, 7, 8]),indices=tensor([1, 2, 0]))
print(torch.max(a,dim=0).indices)
# tensor([1, 2, 0])
torch.argminの使い方
torch.argmin()の引数にTenasor配列を指定すると最小値のインデックスが返ってくる。多次元配列の場合は、平坦化(torch.flatten())された配列のインデックスが返ってくる。
print(torch.argmin(a))
# tensor(0)
print(torch.flatten(a))
# tensor([0, 3, 8, 6, 1, 4, 5, 7, 2])
第二引数dimを指定すると、各軸に沿って最小値のインデックスが返ってくる。
print(torch.argmin(a, dim=0))
# tensor([0, 1, 2])
print(torch.argmin(a, dim=1))
# tensor([0, 1, 2])
torch.minでの取得方法
torch.min()でも、引数dimを指定している場合は、最小値のインデックスがindicesとして返ってくる。以下の場合、indicesを取り出すとtorch.argmin()と同様な結果になることが分かる。
print(torch.min(a,dim=0))
# torch.return_types.min(values=tensor([0, 1, 2]),indices=tensor([0, 1, 2]))
print(torch.min(a,dim=0).indices)
# tensor([1, 2, 0])
コメント