PyTorchのtorchvision.modelsを用いることで、ResNetやEfficientNetなどの有名なモデルを簡単に使うことができ、ファインチューニングなどに利用できます。
torchvision.modelsの使い方
ResNet50の読み込み
ResNet50のモデルの読み込みは、models.resnet50
で行うことができます。引数pretrained
は、学習済み重みを利用するならTrue
にし、重みは利用せず構造のみの場合はFalse
を指定します。
import torch
import torchvision.models as models
model_resnet50 = models.resnet50(pretrained=True)
print(model_resnet50)
# ResNet(
# (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (relu): ReLU(inplace=True)
# (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
# (layer1): Sequential(
# |
# |
# | 省略
# |
# |
# (1): Bottleneck(
# (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True,
# (2): Bottleneck(
# (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
# (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
# (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (relu): ReLU(inplace=True)
# )
# )
# (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
# (fc): Linear(in_features=2048, out_features=1000, bias=True)
# )
printでモデル構造を確認すると、最後の層(fc)
でout_features=1000
となっているので、このモデルは1000クラスの分類ができることが分かります。
ResNet50の出力クラス数を変更する
ResNet50を読み込んでも、1000クラスの分類なので、そのまま自身のクラス分類に利用できません。そのため、モデルの読み込み後に最終層を修正する必要があります。ここでは、10クラス分類に変更します。先ほど確認したように、最終層の名前はfcだったので、以下コードのように、fcを新しく定義した層で入れ替えるだけです。
model_resnet50.fc = torch.nn.Linear(model_resnet50.fc.in_features, 10)
print(model_resnet50)
# |
# | 省略
# |
# )
# (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
# (fc): Linear(in_features=2048, out_features=10, bias=True)
# )
あとはこのモデルを、いつも通り学習させれば転移学習ができます。
AlexNet、GoogLeNet、MobileNetV3、EfficientNetの読み込み
AlexNet、GoogLeNet、MobileNetV3、EfficientNetなど他の有名なモデルも同じように使うことができます。
# AlexNet
alexnet = models.alexnet()
# GoogLeNet
googlenet = models.googlenet()
# MobileNetV3
mobilenet_v3_large = models.mobilenet_v3_large()
# EfficientNet
efficientnet_b7 = models.efficientnet_b7()
参考資料
第11章の事前学習済みモデルの利用で、丁寧に説明されており分かりやすいのでオススメです。
コメント