スポンサーリンク

PyTochで最も近い整数に丸めるtorch.round

Python

PyTorchで最も近い整数に丸めるにはtorch.roundを使う。

torch.roundの使い方

torch.roundの第一引数inputにTensor配列、第二引数decimalsに丸める小数点の桁数を指定する。

import torch

a = torch.tensor((4.31, -2.24, 9.15, -6.71))

print(torch.round(a, decimals=0))
# tensor([ 4., -2.,  9., -7.])
print(torch.round(a, decimals=1))
# tensor([ 4.3000, -2.2000,  9.2000, -6.7000])
print(torch.round(a, decimals=2))
# tensor([ 4.3100, -2.2400,  9.1500, -6.7100])

torch.roundは、数値が二つの整数から等距離の場合は、偶数への丸めであり、四捨五入ではないので注意する必要がある。

b = torch.tensor((0.5, -1.5, 1.5, 2.5))

print(torch.round(b, decimals=0))
# tensor([ 0., -2.,  2.,  2.])

関連記事、参考資料

コメント