PyTorchのTensor配列の次元を変える関数の一つ、torch.viewをみていきます。
torch.viewの使い方
4行4列のTensor配列をtorch.viewによって変換していきます。
import torch
x = torch.tensor([[ 0., 1., 2., 3.],
[4., 5., 6., 7.],
[8., 9., 10., 11.],
[12., 13., 14., 15.]])
print(x.shape)
# torch.Size([4, 4])
基本的な使い方
torch.viewの引数に変換後の配列の形状をしていします。ただし、入力配列の要素数と出力配列の要素数を一致させる必要があり、一致していない場合はRuntimeErrorが発生します。
x_view = x.view(2, 8)
print(x_view.shape)
# torch.Size([2, 8])
x_view = x.view(16, 1)
print(x_view.shape)
# torch.Size([16, 1])
x_view = x.view(2, 2, 4)
print(x_view.shape)
# torch.Size([2, 2, 4])
x_view = x.view(5, 2, 4)
print(x_view.shape)
# RuntimeError: shape '[5, 2, 4]' is invalid for input of size 16
第一引数に-1を指定する使い方
第一引数に-1を指定すると、第二引数以降に指定された値から、自動的に出力配列の形を決めます。
x_view = x.view(-1)
print(x_view.shape)
# torch.Size([16])
x_view = x.view(-1, 1)
print(x_view.shape)
# torch.Size([16, 1])
x_view = x.view(-1, 2)
print(x_view.shape)
# torch.Size([8, 2])
x_view = x.view(-1, 2, 2)
print(x_view.shape)
# torch.Size([4, 2, 2])
RuntimeError: view size is not compatible with input tensor’s size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(…) instead.が発生した場合
torch.viewで変換する場合は、入力配列の値がメモリ上に順番に並んでいないと以下のようにエラーが発生します。
x.T.view(16, 1)
# RuntimeError: view size is not compatible with input tensor's size and stride
# (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
この場合は、contiguousを使用して連続したメモリにデータを置くか、reshapeを使う方法があります。
x_view = x.T.contiguous().view(16, 1)
print(x_view.shape)
# torch.Size([16, 1])
x_view = x.T.reshape(16, 1)
print(x_view.shape)
# torch.Size([16, 1])
関連記事、参考記事
PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。
コメント