PyTorchで最も近い整数に丸めるにはtorch.roundを使う。
torch.roundの使い方
torch.roundの第一引数inputにTensor配列、第二引数decimalsに丸める小数点の桁数を指定する。
import torch
a = torch.tensor((4.31, -2.24, 9.15, -6.71))
print(torch.round(a, decimals=0))
# tensor([ 4., -2., 9., -7.])
print(torch.round(a, decimals=1))
# tensor([ 4.3000, -2.2000, 9.2000, -6.7000])
print(torch.round(a, decimals=2))
# tensor([ 4.3100, -2.2400, 9.1500, -6.7100])
torch.roundは、数値が二つの整数から等距離の場合は、偶数への丸めであり、四捨五入ではないので注意する必要がある。
b = torch.tensor((0.5, -1.5, 1.5, 2.5))
print(torch.round(b, decimals=0))
# tensor([ 0., -2., 2., 2.])
コメント