スポンサーリンク

PyTorchで最頻値を算出するtorch.mode

Python

PyTorchで最頻値を算出するにはtorch.modeを使う。

torch.modeの使い方

torch.modeの第一引数inputにTensor配列を指定すると、出力にタプル[値、インデックス]が返ってくる。下記のようなTensor配列aを入力すると、valuesに最頻値である4、indicesに最頻値4のインデックスがTensor配列で返ってくる。.valuesで最頻値を値として取り出すことができる。

import torch

a = torch.tensor([1, 2, 3, 3, 4, 4, 4, 4])
print(a.shape)
# torch.Size([8])

print(torch.mode(input=a))
# torch.return_types.mode(values=tensor(4),indices=tensor(7))

print(torch.mode(input=a).values)
# tensor(4)

二次元配列以上の場合は、第二引数dimを指定することで、各軸に沿った最頻値を算出することができる。

b = torch.tensor([[1, 2, 3, 3, 4, 4, 4, 4],
                  [1, 1, 2, 3, 4, 5, 6, 7]])

print(b.shape)
# torch.Size([2, 8])

print(torch.mode(input=b, dim=0))
# torch.return_types.mode(values=tensor([1, 1, 2, 3, 4, 4, 4, 4]),indices=tensor([1, 1, 1, 1, 1, 0, 0, 0]))

print(torch.mode(input=b, dim=1))
# torch.return_types.mode(values=tensor([4, 1]),indices=tensor([7, 1]))

関連記事、参考資料

コメント