スポンサーリンク

【PyTorch reshape】Tensor配列の形状を変換するtorch.reshape

Python

PyTorchでTensor配列の形状を変換するには、torch.reshapeを使う。

torch.reshapeの使い方

torch.reshapeの第一引数に入力のTensor配列を、第二引数に出力配列の形状をタプルで指定します。入力配列の要素数と出力配列の要素数が一致しない場合は、RuntimeErrorが発生します。

import torch

a = torch.arange(4.)
print(a)
# tensor([0., 1., 2., 3.])
print(a.shape)
# torch.Size([4])

a_reshape = torch.reshape(a, (2, 2))
print(a_reshape)
# tensor([[0., 1.],
#        [2., 3.]])
print(a_reshape.shape)
# torch.Size([2, 2])

a_reshape = torch.reshape(a, (1, 4))
print(a_reshape)
# tensor([[0., 1., 2., 3.]])
print(a_reshape.shape)
# torch.Size([1, 4])

a_reshape = torch.reshape(a, (1, 5))
# RuntimeError: shape '[1, 5]' is invalid for input of size 4

第二引数のいずれかの次元に-1を指定した場合は、他の次元の指定した値に応じて自動で形状が設定されます。

a_reshape = torch.reshape(a, (-1, 2))
print(a_reshape)
# tensor([[0., 1.],
#        [2., 3.]])
print(a_reshape.shape)
# torch.Size([2, 2])

a_reshape = torch.reshape(a, (1, -1))
print(a_reshape)
# tensor([[0., 1., 2., 3.]])
print(a_reshape.shape)
# torch.Size([1, 4])

関連記事、参考資料

PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント