PyTorchでTensor配列の符号を取得するには、torch.sign()を使う。
torch.signの使い方
torch.signの引数にTensor配列を指定する。負の値は-1、正の値は1、0は0が返ってくる。データ型は引数に指定したTenor配列と同じ型になる。
import torch
a = torch.tensor([0.7, -1.2, 0., 2.3])
print(torch.sign(a))
# tensor([ 1., -1., 0., 1.])
print(a.dtype)
# torch.float32
a = torch.tensor([0, -1, 0, 2])
print(torch.sign(a))
# tensor([ 0, -1, 0, 1])
print(a.dtype)
# torch.int64
負の0、無限、負の無限、nanの場合
-0は0、無限torch.infは1、負の無限-torch.infは-1、欠損値torch.nanは0となる。
a = torch.tensor([0, -0, torch.inf, -torch.inf, torch.nan])
print(torch.sign(a))
# tensor([ 0., 0., 1., -1., 0.])
関連記事、参考資料
PyTorch公式ページでも紹介されていた本で、「Tensorの仕組み」から「ディープラーニングの実践プロジェクト:肺がんの早期発見」までステップバイステップで説明しているため、中身をよく理解できます。
コメント