スポンサーリンク

【PyTorch】最大値、最小値を取得するtorch.max、torch.min

PyTorch

PyTorchでTensor配列の最大値、最小値を取得するためには、torch.maxとtorch.minを使います。

torch.max

torch.maxはTensor配列の最大値を取得します。まず対象となるTensor配列を、3行4列で定義します。

import torch

a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],])
print(a)
# tensor([[ 1,  2,  3,  4],
#        [ 5,  6,  7,  8],
#        [ 9, 10, 11, 12]])

print(a.shape)
# torch.Size([3, 4])

オプションを使わずにtorch.maxを使うと、Tensor配列全体の最大値が取得できます。

print(torch.max(a))
# tensor(12)

引数dimを指定することで、次元を指定して最大値を取得できます。次元を指定した場合は、最大値と最大値のインデックスがtupleで返ってくるため、値を抜き出す場合はvaluesを使います。

print(torch.max(a,dim=0))
# torch.return_types.max(values=tensor([ 9, 10, 11, 12]),indices=tensor([2, 2, 2, 2]))

print(torch.max(a,dim=0).values)
# tensor([ 9, 10, 11, 12])

print(torch.max(a,dim=1))
# torch.return_types.max(values=tensor([ 4,  8, 12]),indices=tensor([3, 3, 3]))

print(torch.max(a,dim=1).values)
# tensor([ 4,  8, 12])

torch.min

torch.minはTensor配列の最小値を取得します。

print(torch.min(a))
# tensor(1)

torch.maxと同様に、引数dimを指定することで、次元を指定して最小値を取得できます。次元を指定した場合は、最小値と最小値のインデックスがtupleで返ってくるため、値を抜き出す場合はvaluesを使います。

print(torch.min(a,dim=0))
# torch.return_types.min(values=tensor([1, 2, 3, 4]),indices=tensor([0, 0, 0, 0]))

print(torch.min(a,dim=0).values)
# tensor([1, 2, 3, 4])

print(torch.min(a,dim=1))
# torch.return_types.min(values=tensor([1, 5, 9]),indices=tensor([0, 0, 0]))

print(torch.min(a,dim=1).values)
# tensor([1, 5, 9])

関連記事、参考資料

PyTorch公式ページでも紹介されていた本で、「Tensorの仕組み」から「ディープラーニングの実践プロジェクト:肺がんの早期発見」までステップバイステップで説明しているため、中身をよく理解できます。

PyTorch実践入門 (Compass Booksシリーズ)

コメント