スポンサーリンク

【PyTorch】Tensor配列を反転するtorch.flip、torch.flipud、torch.fliplr

Python

PyTorchのTensor配列を上下左右反転させるにはtorch.flipを使う。他にも、上下方向に反転させるtorch.flipud、左右方向に反転させるtorch.fliplrがある。

上下反転させるtorch.flipudの使い方

Tensor配列を上下反転させるにはtorch.flipudの引数にTensor配列を指定する。

import torch

a2d = torch.tensor([[0, 1],
        [2, 3]])
print(a2d)
# tensor([[0, 1],
#         [2, 3]])

a2d_flipud = torch.flipud(a2d)
print(a2d_flipud)
# tensor([[2, 3],
#        [0, 1]])

1次元配列の場合は、上下反転ではなく左右反転になる。

a1d = torch.tensor([0, 1, 2, 3])
print(a1d)
# tensor([0, 1, 2, 3])

a1d_flipud = torch.flipud(a1d)
print(a1d_flipud)
# tensor([3, 2, 1, 0])

左右反転させるtorch.fliplrの使い方

Tensor配列を左右反転させるにはtorch.fliplrの引数にTensor配列を指定する。

a2d = torch.tensor([[0, 1],
        [2, 3]])
print(a2d)
# tensor([[0, 1],
#         [2, 3]])

a2d_fliplr = torch.fliplr(a2d)
print(a2d_fliplr)
# tensor([[1, 0],
#         [3, 2]])

torch.fliplrは2次元以上の配列に対応しているため、1次元配列を指定するとエラーがでる。

RuntimeError: Input must be >= 2-d.

a1d = torch.tensor([0, 1, 2, 3])
print(a1d)
# tensor([0, 1, 2, 3])

a1d_fliplr = torch.fliplr(a1d)
print(a1d_fliplr)
# RuntimeError: Input must be >= 2-d.

任意の軸で反転させるtorch.flipの使い方

Tensor配列を任意の軸で反転させるにはtorch.flipを使う。 第二引数のdimsに反転させる軸を指定する。dims=0torch.flipuddims=1torch.fliplrと同様の振る舞いをする。

a2d = torch.tensor([[0, 1],
        [2, 3]])
print(a2d)
# tensor([[0, 1],
#         [2, 3]])

a2d_flip = torch.flip(a2d, dims=[0, 1])
print(a2d_flip)
# tensor([[3, 2],
#        [1, 0]])

print(torch.flip(a2d, dims=[0]))
# tensor([[2, 3],
#         [0, 1]])

print(torch.flip(a2d, dims=[1]))
# tensor([[1, 0],
#         [3, 2]])

関連記事、参考資料

PyTorchの入門書で、GPUの利用方法、ネットワークの構築方法や転移学習まで幅広く書かれていてオススメです。

コメント