スポンサーリンク

【PyTorch】連番や等差数列を生成するtorch.arange、torch.linspace

Python

PyTorchで連番や等差数列を生成するには、torch.arangetorch.linspaceを使う。二つの違いは、torch.arangeが間隔を指定するのに対して、torch.linspaceは要素数を指定する。

torch.arangeの使い方

torch.arange()に、3つの引数に値を指定することで、連番や等差数列を生成する。第一引数startに開始値、第二引数endに終了値、第三引数stepに値の間隔を指定する。

import torch

print(torch.arange(0, 5, 1))
# tensor([0, 1, 2, 3, 4])

print(torch.arange(1, 2.5, 0.5))
# tensor([1.0000, 1.5000, 2.0000])

デフォルト値として、第一引数startに0、第三引数endに1が指定されているため、引数を1つだけ指定した場合は、第二引数を与えたことになり、他はデフォルト値が用いられる。

print(torch.arange(5)) 
# torch.arange(0, 5, 1)と同様の振る舞いをする
# tensor([0, 1, 2, 3, 4])

引数を2つ指定した場合は、第一引数startと第二引数を与えたことになる。

print(torch.arange(1,7)) # torch.arange(1, 7, 1)と同様の振る舞いをする
# tensor([1, 2, 3, 4, 5, 6])

torch.linspaceの使い方

torch.linspaceは、要素数を指定することで連番や等差数列を生成する。第一引数startに開始値、第二引数endに終了値、第三引数stepsに生成するTensor配列のサイズを指定する。開始値から終了値まででstepsに応じた等間隔の値を持つTensor配列が生成される。

a = torch.linspace(3, 10, steps=5)
print(a)
print(a.size())
# tensor([ 3.0000,  4.7500,  6.5000,  8.2500, 10.0000])
# torch.Size([5])

b = torch.linspace(-10, 10, steps=10)
print(b)
print(b.size())
# tensor([-10.0000,  -7.7778,  -5.5556,  -3.3333,  -1.1111,   1.1111,   3.3333,
#           5.5556,   7.7778,  10.0000])
# torch.Size([10])

c = torch.linspace(-10, 10, steps=1)
print(c)
print(c.size())
# tensor([-10.])
# torch.Size([1])

関連記事、参考資料

コメント