PyTorchでTensor配列の符号を取得するtorch.sign

Python

PyTorchでTensor配列の符号を取得するには、torch.sign()を使う。

スポンサーリンク

torch.signの使い方

torch.signの引数にTensor配列を指定する。負の値は-1、正の値は1、0は0が返ってくる。データ型は引数に指定したTenor配列と同じ型になる。

import torch

a = torch.tensor([0.7, -1.2, 0., 2.3])
print(torch.sign(a))
# tensor([ 1., -1.,  0.,  1.])
print(a.dtype)
# torch.float32

a = torch.tensor([0, -1, 0, 2])
print(torch.sign(a))
# tensor([ 0, -1,  0,  1])
print(a.dtype)
# torch.int64

負の0、無限、負の無限、nanの場合

-0は0、無限torch.infは1、負の無限-torch.infは-1、欠損値torch.nanは0となる。

a = torch.tensor([0, -0, torch.inf, -torch.inf, torch.nan])
print(torch.sign(a))
# tensor([ 0.,  0.,  1., -1.,  0.])
スポンサーリンク

関連記事、参考資料

PyTorch公式ページでも紹介されていた本で、「Tensorの仕組み」から「ディープラーニングの実践プロジェクト:肺がんの早期発見」までステップバイステップで説明しているため、中身をよく理解できます。

PyTorch実践入門 (Compass Booksシリーズ)

コメント