スポンサーリンク

【PyTorch】Tensor配列をシフト(移動)させるtorch.roll

Python

PyTorchでTensor配列をシフト(移動)させるtorch.rollを使う。

torch.rollの使い方

1次元配列の場合は、torch.rollの第一引数inputにTensor配列、第二引数shiftsにシフトさせる要素数を指定する。shiftsに3を指定した場合は、3番目の要素が先頭になり、0から2番目の要素は一番後ろにシフトされる。

import torch

a = torch.tensor([0, 1, 2, 3, 4, 5])
print(a)
# tensor([0, 1, 2, 3, 4, 5])

a_shift3 = torch.roll(input=a, shifts=3)
print(a_shift3)
# tensor([3, 4, 5, 0, 1, 2])

多次元配列の場合は、1次元化された要素でシフトされ、元の形状に戻される。

a_2d = torch.tensor([0, 1, 2, 3, 4, 5]).view(3, 2)
print(a_2d)
# tensor([[0, 1],
#        [2, 3],
#        [4, 5]])

a_2d_shift3 = torch.roll(input=a_2d, shifts=3)
print(a_2d_shift3)
# tensor([[3, 4],
#        [5, 0],
#        [1, 2]])

# flattenで1次元化した後に、viewで形状を戻した結果は、rollと同様になる
print(torch.roll(input=torch.flatten(a_2d), shifts=3).view(3, 2))
# tensor([[3, 4],
#        [5, 0],
#        [1, 2]])

第三引数dimsを指定すると、その軸に沿ってシフトされる。

print(torch.roll(input=a_2d, shifts=3, dims=0))
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

print(torch.roll(input=a_2d, shifts=3, dims=1))
# ensor([[1, 0],
#        [3, 2],
#        [5, 4]])

第二引数shiftsと第三引数dimsに複数の値を指定すると、複数の軸に沿ってシフトさせることができる。

print(torch.roll(input=a_2d, shifts=(2,1), dims=(0,1)))
# tensor([[3, 2],
#         [5, 4],
#         [1, 0]])

関連記事、参考資料

PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント