スポンサーリンク

【PyTorch】分位数を算出するtorch.quantile

Python

PyTorchで分位数を算出するにはtorch.quantileを使う。箱ひげ図などにも用いられる四分位数もtorch.quantileを使うことで、算出できる。

torch.quantileの使い方

torch.quantileの第一引数inputにTensor配列、第二引数qに1次元のTensor配列で分位数を指定する。四分位数の場合は、qに[0.25, 0.5, 0.75]を指定する。

import torch

a = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])


q = torch.tensor([0.05, 0.95])
print(torch.quantile(input=a, q=q))
# tensor([0.5000, 9.5000])

q = torch.tensor([0.25, 0.5, 0.75])
print(torch.quantile(input=a, q=q))
# tensor([2.5000, 5.0000, 7.5000])

入力Tensorの型は、floatかdoubleである必要があるため、intを指定するとRuntimeErrorが発生します。

a = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
print(torch.quantile(input=a, q=q))
# RuntimeError: quantile() input tensor must be either float or double dtype

引数interpolationを指定すると、目的の分位点が 2 つのデータポイントの間にある場合に補完方法を選ぶことができます。

print(torch.quantile(input=a, q=q, interpolation='linear'))
# tensor([2.5000, 5.0000, 7.5000])
print(torch.quantile(input=a, q=q, interpolation='lower'))
# tensor([2., 5., 7.])
print(torch.quantile(input=a, q=q, interpolation='higher'))
# tensor([3., 5., 8.]
print(torch.quantile(input=a, q=q, interpolation='midpoint'))
# tensor([2.5000, 5.0000, 7.5000])

関連記事、参考資料

コメント