PyTorchで、データセットを学習と検証データに分割したいときには、torch.utils.data.random_splitを使う。
torch.utils.data.random_split
第一引数dataset
に分割するデータセット、第二引数lengths
に分割後のデータ数、第三引数generator
はオプションでデータの分割に再現性を持たせたい場合は指定する。
import torch
data = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
train, val = torch.utils.data.random_split(dataset=data, lengths=[7, 3], generator=torch.Generator().manual_seed(42))
print(list(train))
# [tensor(2), tensor(6), tensor(1), tensor(8), tensor(4), tensor(5), tensor(0)]
print(list(val))
# [tensor(9), tensor(3), tensor(7)]
データは3つ以上に分割することもでき、第二引数lengths
にリストで指定する。
train, val, test = torch.utils.data.random_split(dataset=data, lengths=[6, 3, 1], generator=torch.Generator().manual_seed(42))
print(list(train))
# [tensor(2), tensor(6), tensor(1), tensor(8), tensor(4), tensor(5)]
print(list(val))
# [tensor(0), tensor(9), tensor(3)]
print(list(test))
# [tensor(7)]
コメント