スポンサーリンク

【PyTorch】torchvision.modelsでResNetやEfficientNetの読み込みと分類クラス数の変更、ファインチューニングへの活用

PyTorch

PyTorchのtorchvision.modelsを用いることで、ResNetやEfficientNetなどの有名なモデルを簡単に使うことができ、ファインチューニングなどに利用できます。

torchvision.modelsの使い方

ResNet50の読み込み

ResNet50のモデルの読み込みは、models.resnet50で行うことができます。引数pretrainedは、学習済み重みを利用するならTrueにし、重みは利用せず構造のみの場合はFalseを指定します。

import torch
import torchvision.models as models

model_resnet50 = models.resnet50(pretrained=True)
print(model_resnet50)
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (relu): ReLU(inplace=True)
#   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
#   (layer1): Sequential(
#                                                 |
#                                                 |
#                                                 | 省略
#                                                 |
#                                                 |  
#     (1): Bottleneck(
#       (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
#       (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, 
#     (2): Bottleneck(
#       (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
#       (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
#       (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
#       (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#       (relu): ReLU(inplace=True)
#     )
#   )
#   (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
#   (fc): Linear(in_features=2048, out_features=1000, bias=True)
# )

printでモデル構造を確認すると、最後の層(fc)out_features=1000となっているので、このモデルは1000クラスの分類ができることが分かります。

ResNet50の出力クラス数を変更する

ResNet50を読み込んでも、1000クラスの分類なので、そのまま自身のクラス分類に利用できません。そのため、モデルの読み込み後に最終層を修正する必要があります。ここでは、10クラス分類に変更します。先ほど確認したように、最終層の名前はfcだったので、以下コードのように、fcを新しく定義した層で入れ替えるだけです。

model_resnet50.fc = torch.nn.Linear(model_resnet50.fc.in_features, 10)
print(model_resnet50)
#                                                 |
#                                                 | 省略
#                                                 |
#   )
#   (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
#  (fc): Linear(in_features=2048, out_features=10, bias=True)
# )

あとはこのモデルを、いつも通り学習させれば転移学習ができます。

AlexNet、GoogLeNet、MobileNetV3、EfficientNetの読み込み

AlexNet、GoogLeNet、MobileNetV3、EfficientNetなど他の有名なモデルも同じように使うことができます。

# AlexNet
alexnet = models.alexnet()
# GoogLeNet
googlenet = models.googlenet()
# MobileNetV3
mobilenet_v3_large = models.mobilenet_v3_large()
# EfficientNet
efficientnet_b7 = models.efficientnet_b7()

参考資料

第11章の事前学習済みモデルの利用で、丁寧に説明されており分かりやすいのでオススメです。

コメント