PyTorchのモデル構造を可視化するライブラリtorchinfo
について見ていきます。このライブラリを用いるとTensorflow/Kerasのmodel.summary()
のようにモデルを表示することができます。
今回、torchvision.models
からresnet18
を読み込み、可視化するモデルとします。
import torch
import torchvision.models as models
model_resnet18 = models.resnet18()
- 関連記事 – 【PyTorch】torchvision.modelsでResNetやEfficientNetの読み込みと分類クラス数の変更、転移学習への活用
- torchvision.models – PyTorch documentation
torchinfoの使い方
pipかcondaのいずれかでインストールします。
pip install torchinfo
conda install -c conda-forge torchinfo
torchinfoのsummary関数を使うことで、モデルを表示します。第一引数model
にPyTorchモデルを、第二引数input_size
に、バッチ数と入力サイズを指定します。
from torchinfo import summary
batch_size = 16
summary(model=model_resnet18, input_size=(batch_size, 3, 16, 16))
=================================================
Layer (type:depth-idx) Output Shape Param #
=================================================
ResNet -- --
├─Conv2d: 1-1 [16, 64, 8, 8] 9,408
├─BatchNorm2d: 1-2 [16, 64, 8, 8] 128
├─ReLU: 1-3 [16, 64, 8, 8] --
├─MaxPool2d: 1-4 [16, 64, 4, 4] --
├─Sequential: 1-5 [16, 64, 4, 4] --
│ └─BasicBlock: 2-1 [16, 64, 4, 4] --
│ │ └─Conv2d: 3-1 [16, 64, 4, 4] 36,864
│ │ └─BatchNorm2d: 3-2 [16, 64, 4, 4] 128
│ │ └─ReLU: 3-3 [16, 64, 4, 4] --
│ │ └─Conv2d: 3-4 [16, 64, 4, 4] 36,864
│ │ └─BatchNorm2d: 3-5 [16, 64, 4, 4] 128
│ │ └─ReLU: 3-6 [16, 64, 4, 4] --
│ └─BasicBlock: 2-2 [16, 64, 4, 4] --
│ │ └─Conv2d: 3-7 [16, 64, 4, 4] 36,864
│ │ └─BatchNorm2d: 3-8 [16, 64, 4, 4] 128
│ │ └─ReLU: 3-9 [16, 64, 4, 4] --
│ │ └─Conv2d: 3-10 [16, 64, 4, 4] 36,864
│ │ └─BatchNorm2d: 3-11 [16, 64, 4, 4] 128
│ │ └─ReLU: 3-12 [16, 64, 4, 4] --
├─Sequential: 1-6 [16, 128, 2, 2] --
│ └─BasicBlock: 2-3 [16, 128, 2, 2] --
│ │ └─Conv2d: 3-13 [16, 128, 2, 2] 73,728
│ │ └─BatchNorm2d: 3-14 [16, 128, 2, 2] 256
│ │ └─ReLU: 3-15 [16, 128, 2, 2] --
│ │ └─Conv2d: 3-16 [16, 128, 2, 2] 147,456
│ │ └─BatchNorm2d: 3-17 [16, 128, 2, 2] 256
│ │ └─Sequential: 3-18 [16, 128, 2, 2] 8,448
│ │ └─ReLU: 3-19 [16, 128, 2, 2] --
│ └─BasicBlock: 2-4 [16, 128, 2, 2] --
│ │ └─Conv2d: 3-20 [16, 128, 2, 2] 147,456
│ │ └─BatchNorm2d: 3-21 [16, 128, 2, 2] 256
│ │ └─ReLU: 3-22 [16, 128, 2, 2] --
│ │ └─Conv2d: 3-23 [16, 128, 2, 2] 147,456
│ │ └─BatchNorm2d: 3-24 [16, 128, 2, 2] 256
│ │ └─ReLU: 3-25 [16, 128, 2, 2] --
├─Sequential: 1-7 [16, 256, 1, 1] --
│ └─BasicBlock: 2-5 [16, 256, 1, 1] --
│ │ └─Conv2d: 3-26 [16, 256, 1, 1] 294,912
│ │ └─BatchNorm2d: 3-27 [16, 256, 1, 1] 512
│ │ └─ReLU: 3-28 [16, 256, 1, 1] --
│ │ └─Conv2d: 3-29 [16, 256, 1, 1] 589,824
│ │ └─BatchNorm2d: 3-30 [16, 256, 1, 1] 512
│ │ └─Sequential: 3-31 [16, 256, 1, 1] 33,280
│ │ └─ReLU: 3-32 [16, 256, 1, 1] --
│ └─BasicBlock: 2-6 [16, 256, 1, 1] --
│ │ └─Conv2d: 3-33 [16, 256, 1, 1] 589,824
│ │ └─BatchNorm2d: 3-34 [16, 256, 1, 1] 512
│ │ └─ReLU: 3-35 [16, 256, 1, 1] --
│ │ └─Conv2d: 3-36 [16, 256, 1, 1] 589,824
│ │ └─BatchNorm2d: 3-37 [16, 256, 1, 1] 512
│ │ └─ReLU: 3-38 [16, 256, 1, 1] --
├─Sequential: 1-8 [16, 512, 1, 1] --
│ └─BasicBlock: 2-7 [16, 512, 1, 1] --
│ │ └─Conv2d: 3-39 [16, 512, 1, 1] 1,179,648
│ │ └─BatchNorm2d: 3-40 [16, 512, 1, 1] 1,024
│ │ └─ReLU: 3-41 [16, 512, 1, 1] --
│ │ └─Conv2d: 3-42 [16, 512, 1, 1] 2,359,296
│ │ └─BatchNorm2d: 3-43 [16, 512, 1, 1] 1,024
│ │ └─Sequential: 3-44 [16, 512, 1, 1] 132,096
│ │ └─ReLU: 3-45 [16, 512, 1, 1] --
│ └─BasicBlock: 2-8 [16, 512, 1, 1] --
│ │ └─Conv2d: 3-46 [16, 512, 1, 1] 2,359,296
│ │ └─BatchNorm2d: 3-47 [16, 512, 1, 1] 1,024
│ │ └─ReLU: 3-48 [16, 512, 1, 1] --
│ │ └─Conv2d: 3-49 [16, 512, 1, 1] 2,359,296
│ │ └─BatchNorm2d: 3-50 [16, 512, 1, 1] 1,024
│ │ └─ReLU: 3-51 [16, 512, 1, 1] --
├─AdaptiveAvgPool2d: 1-9 [16, 512, 1, 1] --
├─Linear: 1-10 [16, 1000] 513,000
=================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (M): 257.07
=================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 3.86
Params size (MB): 46.76
Estimated Total Size (MB): 50.67
=================================================
各レイヤーごとのパラメータ数、出力形状、Forward/backward時の計算容量なども表示してくれます。
以前は、Tensorflow/Kerasのmodel.summary()
のように表示させるには、torchsummaryライブラリが使われていたようですが、こちらはもうアップデートされていないようで、torchinfoライブラリに吸収されたようです。
関連資料、関連記事
- torchinfo
GitHub - TylerYep/torchinfo: View model summaries in PyTorch!
View model summaries in PyTorch! Contribute to TylerYep/torchinfo development by creating an account on GitHub.
- torchsummary
GitHub - sksq96/pytorch-summary: Model summary in PyTorch similar to `model.summary()` in Keras
Model summary in PyTorch similar to `model.summary()` in Keras - sksq96/pytorch-summary
コメント