スポンサーリンク

PyTorchでデータを分割するtorch.utils.data.random_split【学習、検証】

Machine Learning

PyTorchで、データセットを学習と検証データに分割したいときには、torch.utils.data.random_splitを使う。

torch.utils.data.random_split

第一引数datasetに分割するデータセット、第二引数lengthsに分割後のデータ数、第三引数generatorはオプションでデータの分割に再現性を持たせたい場合は指定する。

import torch

data = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

train, val = torch.utils.data.random_split(dataset=data, lengths=[7, 3], generator=torch.Generator().manual_seed(42))

print(list(train))
# [tensor(2), tensor(6), tensor(1), tensor(8), tensor(4), tensor(5), tensor(0)]
print(list(val))
# [tensor(9), tensor(3), tensor(7)]

データは3つ以上に分割することもでき、第二引数lengthsにリストで指定する。

train, val, test = torch.utils.data.random_split(dataset=data, lengths=[6, 3, 1], generator=torch.Generator().manual_seed(42))

print(list(train))
# [tensor(2), tensor(6), tensor(1), tensor(8), tensor(4), tensor(5)]
print(list(val))
# [tensor(0), tensor(9), tensor(3)]
print(list(test))
# [tensor(7)]

関連記事、関連資料

コメント