PyTorchのTensor配列を既存の次元(軸)に沿って結合するにはtorch.catを使う。
torch.catの使い方
以下のTensor配列を用いる。
import torch
a = torch.ones(2, 3)
print(a)
print(a.size())
# tensor([[1., 1., 1.],
# [1., 1., 1.]])
# torch.Size([2, 3])
b = 2*torch.ones(2, 3)
print(b)
print(b.size())
# tensor([[2., 2., 2.],
# [2., 2., 2.]])
# torch.Size([2, 3])
torch.catの第一引数に結合するTenosor配列、第二引数dimに結合する次元を指定する。dimに存在しない次元を指定するとIndexErrorが発生し、新しいTensor配列は生成されない。
a_b = torch.cat((a, b), dim=0)
print(a_b)
print(a_b.size())
# tensor([[1., 1., 1.],
# [1., 1., 1.],
# [2., 2., 2.],
# [2., 2., 2.]])
# torch.Size([4, 3])
a_b = torch.cat((a, b), dim=1)
print(a_b)
print(a_b.size())
# tensor([[1., 1., 1., 2., 2., 2.],
# [1., 1., 1., 2., 2., 2.]])
# torch.Size([2, 6])
a_b = torch.cat((a, b), dim=2)
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
Tensor配列の結合はいくつでも行うことができる。4つ結合する場合は以下のように書く。
a_b = torch.cat((a, b, a, b), dim=0)
print(a_b)
print(a_b.size())
# tensor([[1., 1., 1.],
# [1., 1., 1.],
# [2., 2., 2.],
# [2., 2., 2.],
# [1., 1., 1.],
# [1., 1., 1.],
# [2., 2., 2.],
# [2., 2., 2.]])
# torch.Size([8, 3])
結合する次元(軸)以外の軸のサイズが一致していない場合はRuntimeErrorが発生し、新しいTensor配列は生成されない。
c = 3*torch.ones(3, 4)
a_b_c = torch.cat((a, b, c), dim=0)
# RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 4 for tensor number 2 in the list.
関連記事、参考資料
PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。
コメント