スポンサーリンク

【PyTorch】Tensor配列のサイズ1の次元を削除するtorch.squeeze、指定した位置にサイズ1の次元を挿入するtorch.unsqueeze

Python

Tensor配列のサイズ1の次元を削除するにはtorch.squeeze、指定した位置にサイズ1の次元を挿入するにはtorch.unsqueezeを使用する。

torch.squeezeの使い方

サイズ1の次元を削除するにはtorch.squeezeを用いる。torch.squeezeにTensor配列を指定することで、サイズ1の次元が削除されたTensor配列が生成される。

import torch

x = torch.zeros(2, 1, 2, 1, 2)
print(x.size())
# torch.Size([2, 1, 2, 1, 2])

y = torch.squeeze(x)
print(y.size())
# torch.Size([2, 2, 2])

引数dimを指定することで、指定した次元のみ削除することができる。ただし、指定した次元のサイズが1でない場合は削除されない。

y = torch.squeeze(x, dim=1)
print(y.size())
# torch.Size([2, 2, 1, 2])

y = torch.squeeze(x, dim=3)
print(y.size())
# torch.Size([2, 1, 2, 2])

y = torch.squeeze(x, dim=0)
print(y.size())
# torch.Size([2, 1, 2, 1, 2])

torch.unsqueezeの使い方

指定した位置にサイズ1の次元を挿入するにはtorch.unsqueezeを用いる。torch.unsqueezeにTensor配列を指定し、引数dimに次元を挿入する位置を指定する。一番後ろに次元を追加したい場合は、dimに-1を指定する

x = torch.zeros(2, 3, 4, 5, 6)
print(x.size())
# torch.Size([2, 3, 4, 5, 6])

print(torch.unsqueeze(x, dim=0).size())
# torch.Size([1, 2, 3, 4, 5, 6])

print(torch.unsqueeze(x, dim=-1).size())
# torch.Size([2, 3, 4, 5, 6, 1])

関連記事、参考資料

PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント