PyTorchで最頻値を算出するにはtorch.modeを使う。
torch.modeの使い方
torch.modeの第一引数inputにTensor配列を指定すると、出力にタプル[値、インデックス]が返ってくる。下記のようなTensor配列aを入力すると、valuesに最頻値である4、indicesに最頻値4のインデックスがTensor配列で返ってくる。.valuesで最頻値を値として取り出すことができる。
import torch
a = torch.tensor([1, 2, 3, 3, 4, 4, 4, 4])
print(a.shape)
# torch.Size([8])
print(torch.mode(input=a))
# torch.return_types.mode(values=tensor(4),indices=tensor(7))
print(torch.mode(input=a).values)
# tensor(4)
二次元配列以上の場合は、第二引数dimを指定することで、各軸に沿った最頻値を算出することができる。
b = torch.tensor([[1, 2, 3, 3, 4, 4, 4, 4],
[1, 1, 2, 3, 4, 5, 6, 7]])
print(b.shape)
# torch.Size([2, 8])
print(torch.mode(input=b, dim=0))
# torch.return_types.mode(values=tensor([1, 1, 2, 3, 4, 4, 4, 4]),indices=tensor([1, 1, 1, 1, 1, 0, 0, 0]))
print(torch.mode(input=b, dim=1))
# torch.return_types.mode(values=tensor([4, 1]),indices=tensor([7, 1]))
コメント