Tensor配列のサイズ1の次元を削除するにはtorch.squeeze、指定した位置にサイズ1の次元を挿入するにはtorch.unsqueezeを使用する。
torch.squeezeの使い方
サイズ1の次元を削除するにはtorch.squeezeを用いる。torch.squeezeにTensor配列を指定することで、サイズ1の次元が削除されたTensor配列が生成される。
import torch
x = torch.zeros(2, 1, 2, 1, 2)
print(x.size())
# torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x)
print(y.size())
# torch.Size([2, 2, 2])
引数dimを指定することで、指定した次元のみ削除することができる。ただし、指定した次元のサイズが1でない場合は削除されない。
y = torch.squeeze(x, dim=1)
print(y.size())
# torch.Size([2, 2, 1, 2])
y = torch.squeeze(x, dim=3)
print(y.size())
# torch.Size([2, 1, 2, 2])
y = torch.squeeze(x, dim=0)
print(y.size())
# torch.Size([2, 1, 2, 1, 2])
torch.unsqueezeの使い方
指定した位置にサイズ1の次元を挿入するにはtorch.unsqueezeを用いる。torch.unsqueezeにTensor配列を指定し、引数dimに次元を挿入する位置を指定する。一番後ろに次元を追加したい場合は、dimに-1を指定する
。
x = torch.zeros(2, 3, 4, 5, 6)
print(x.size())
# torch.Size([2, 3, 4, 5, 6])
print(torch.unsqueeze(x, dim=0).size())
# torch.Size([1, 2, 3, 4, 5, 6])
print(torch.unsqueeze(x, dim=-1).size())
# torch.Size([2, 3, 4, 5, 6, 1])
関連記事、参考資料
PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。
コメント