スポンサーリンク

【PyTorch】既存の次元(軸)に沿って結合するtorch.cat

Python

PyTorchのTensor配列を既存の次元(軸)に沿って結合するにはtorch.catを使う。

torch.catの使い方

以下のTensor配列を用いる。

import torch

a = torch.ones(2, 3)
print(a)
print(a.size())
# tensor([[1., 1., 1.],
#         [1., 1., 1.]])
# torch.Size([2, 3])

b = 2*torch.ones(2, 3)
print(b)
print(b.size())
# tensor([[2., 2., 2.],
#         [2., 2., 2.]])
# torch.Size([2, 3])

torch.catの第一引数に結合するTenosor配列、第二引数dimに結合する次元を指定する。dimに存在しない次元を指定するとIndexErrorが発生し、新しいTensor配列は生成されない。

a_b = torch.cat((a, b), dim=0)
print(a_b)
print(a_b.size())
# tensor([[1., 1., 1.],
#         [1., 1., 1.],
#         [2., 2., 2.],
#         [2., 2., 2.]])
# torch.Size([4, 3])

a_b = torch.cat((a, b), dim=1)
print(a_b)
print(a_b.size())
# tensor([[1., 1., 1., 2., 2., 2.],
#         [1., 1., 1., 2., 2., 2.]])
# torch.Size([2, 6])

a_b = torch.cat((a, b), dim=2)
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

Tensor配列の結合はいくつでも行うことができる。4つ結合する場合は以下のように書く。

a_b = torch.cat((a, b, a, b), dim=0)
print(a_b)
print(a_b.size())
# tensor([[1., 1., 1.],
#         [1., 1., 1.],
#         [2., 2., 2.],
#         [2., 2., 2.],
#         [1., 1., 1.],
#         [1., 1., 1.],
#         [2., 2., 2.],
#         [2., 2., 2.]])
# torch.Size([8, 3])

結合する次元(軸)以外の軸のサイズが一致していない場合はRuntimeErrorが発生し、新しいTensor配列は生成されない。

c = 3*torch.ones(3, 4)

a_b_c = torch.cat((a, b, c), dim=0)
# RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 3 but got size 4 for tensor number 2 in the list.

関連記事、参考資料

PyTorch公式からも紹介されている本で、Tensor配列の基礎的な内容から実際の画像データを用いた実践的な内容まで網羅しています。

コメント