PyTorchでTensor配列の最大値、最小値を取得するためには、torch.maxとtorch.minを使います。
torch.max
torch.maxはTensor配列の最大値を取得します。まず対象となるTensor配列を、3行4列で定義します。
import torch
a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],])
print(a)
# tensor([[ 1, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
print(a.shape)
# torch.Size([3, 4])
オプションを使わずにtorch.maxを使うと、Tensor配列全体の最大値が取得できます。
print(torch.max(a))
# tensor(12)
引数dimを指定することで、次元を指定して最大値を取得できます。次元を指定した場合は、最大値と最大値のインデックスがtupleで返ってくるため、値を抜き出す場合はvaluesを使います。
print(torch.max(a,dim=0))
# torch.return_types.max(values=tensor([ 9, 10, 11, 12]),indices=tensor([2, 2, 2, 2]))
print(torch.max(a,dim=0).values)
# tensor([ 9, 10, 11, 12])
print(torch.max(a,dim=1))
# torch.return_types.max(values=tensor([ 4, 8, 12]),indices=tensor([3, 3, 3]))
print(torch.max(a,dim=1).values)
# tensor([ 4, 8, 12])
torch.min
torch.minはTensor配列の最小値を取得します。
print(torch.min(a))
# tensor(1)
torch.maxと同様に、引数dimを指定することで、次元を指定して最小値を取得できます。次元を指定した場合は、最小値と最小値のインデックスがtupleで返ってくるため、値を抜き出す場合はvaluesを使います。
print(torch.min(a,dim=0))
# torch.return_types.min(values=tensor([1, 2, 3, 4]),indices=tensor([0, 0, 0, 0]))
print(torch.min(a,dim=0).values)
# tensor([1, 2, 3, 4])
print(torch.min(a,dim=1))
# torch.return_types.min(values=tensor([1, 5, 9]),indices=tensor([0, 0, 0]))
print(torch.min(a,dim=1).values)
# tensor([1, 5, 9])
関連記事、参考資料
- 関連記事 – 2つのTensor配列の要素ごとの最大値・最小値を取得するtorch.maximum、torch.minimum、torch.fmax、torch.fmin【PyTorch】
- PyTorch 記事一覧
PyTorch公式ページでも紹介されていた本で、「Tensorの仕組み」から「ディープラーニングの実践プロジェクト:肺がんの早期発見」までステップバイステップで説明しているため、中身をよく理解できます。
コメント