スポンサーリンク

【PyTorch】モデル構造を可視化するtorchinfo

Machine Learning

PyTorchのモデル構造を可視化するライブラリtorchinfoについて見ていきます。このライブラリを用いるとTensorflow/Kerasのmodel.summary()のようにモデルを表示することができます。

今回、torchvision.modelsからresnet18を読み込み、可視化するモデルとします。

import torch
import torchvision.models as models

model_resnet18 = models.resnet18()

torchinfoの使い方

pipかcondaのいずれかでインストールします。

pip install torchinfo
conda install -c conda-forge torchinfo

torchinfoのsummary関数を使うことで、モデルを表示します。第一引数modelにPyTorchモデルを、第二引数input_sizeに、バッチ数と入力サイズを指定します。

from torchinfo import summary

batch_size = 16
summary(model=model_resnet18, input_size=(batch_size, 3, 16, 16))
=================================================
Layer (type:depth-idx)                   Output Shape              Param #
=================================================
ResNet                                   --                        --
├─Conv2d: 1-1                            [16, 64, 8, 8]            9,408
├─BatchNorm2d: 1-2                       [16, 64, 8, 8]            128
├─ReLU: 1-3                              [16, 64, 8, 8]            --
├─MaxPool2d: 1-4                         [16, 64, 4, 4]            --
├─Sequential: 1-5                        [16, 64, 4, 4]            --
│    └─BasicBlock: 2-1                   [16, 64, 4, 4]            --
│    │    └─Conv2d: 3-1                  [16, 64, 4, 4]            36,864
│    │    └─BatchNorm2d: 3-2             [16, 64, 4, 4]            128
│    │    └─ReLU: 3-3                    [16, 64, 4, 4]            --
│    │    └─Conv2d: 3-4                  [16, 64, 4, 4]            36,864
│    │    └─BatchNorm2d: 3-5             [16, 64, 4, 4]            128
│    │    └─ReLU: 3-6                    [16, 64, 4, 4]            --
│    └─BasicBlock: 2-2                   [16, 64, 4, 4]            --
│    │    └─Conv2d: 3-7                  [16, 64, 4, 4]            36,864
│    │    └─BatchNorm2d: 3-8             [16, 64, 4, 4]            128
│    │    └─ReLU: 3-9                    [16, 64, 4, 4]            --
│    │    └─Conv2d: 3-10                 [16, 64, 4, 4]            36,864
│    │    └─BatchNorm2d: 3-11            [16, 64, 4, 4]            128
│    │    └─ReLU: 3-12                   [16, 64, 4, 4]            --
├─Sequential: 1-6                        [16, 128, 2, 2]           --
│    └─BasicBlock: 2-3                   [16, 128, 2, 2]           --
│    │    └─Conv2d: 3-13                 [16, 128, 2, 2]           73,728
│    │    └─BatchNorm2d: 3-14            [16, 128, 2, 2]           256
│    │    └─ReLU: 3-15                   [16, 128, 2, 2]           --
│    │    └─Conv2d: 3-16                 [16, 128, 2, 2]           147,456
│    │    └─BatchNorm2d: 3-17            [16, 128, 2, 2]           256
│    │    └─Sequential: 3-18             [16, 128, 2, 2]           8,448
│    │    └─ReLU: 3-19                   [16, 128, 2, 2]           --
│    └─BasicBlock: 2-4                   [16, 128, 2, 2]           --
│    │    └─Conv2d: 3-20                 [16, 128, 2, 2]           147,456
│    │    └─BatchNorm2d: 3-21            [16, 128, 2, 2]           256
│    │    └─ReLU: 3-22                   [16, 128, 2, 2]           --
│    │    └─Conv2d: 3-23                 [16, 128, 2, 2]           147,456
│    │    └─BatchNorm2d: 3-24            [16, 128, 2, 2]           256
│    │    └─ReLU: 3-25                   [16, 128, 2, 2]           --
├─Sequential: 1-7                        [16, 256, 1, 1]           --
│    └─BasicBlock: 2-5                   [16, 256, 1, 1]           --
│    │    └─Conv2d: 3-26                 [16, 256, 1, 1]           294,912
│    │    └─BatchNorm2d: 3-27            [16, 256, 1, 1]           512
│    │    └─ReLU: 3-28                   [16, 256, 1, 1]           --
│    │    └─Conv2d: 3-29                 [16, 256, 1, 1]           589,824
│    │    └─BatchNorm2d: 3-30            [16, 256, 1, 1]           512
│    │    └─Sequential: 3-31             [16, 256, 1, 1]           33,280
│    │    └─ReLU: 3-32                   [16, 256, 1, 1]           --
│    └─BasicBlock: 2-6                   [16, 256, 1, 1]           --
│    │    └─Conv2d: 3-33                 [16, 256, 1, 1]           589,824
│    │    └─BatchNorm2d: 3-34            [16, 256, 1, 1]           512
│    │    └─ReLU: 3-35                   [16, 256, 1, 1]           --
│    │    └─Conv2d: 3-36                 [16, 256, 1, 1]           589,824
│    │    └─BatchNorm2d: 3-37            [16, 256, 1, 1]           512
│    │    └─ReLU: 3-38                   [16, 256, 1, 1]           --
├─Sequential: 1-8                        [16, 512, 1, 1]           --
│    └─BasicBlock: 2-7                   [16, 512, 1, 1]           --
│    │    └─Conv2d: 3-39                 [16, 512, 1, 1]           1,179,648
│    │    └─BatchNorm2d: 3-40            [16, 512, 1, 1]           1,024
│    │    └─ReLU: 3-41                   [16, 512, 1, 1]           --
│    │    └─Conv2d: 3-42                 [16, 512, 1, 1]           2,359,296
│    │    └─BatchNorm2d: 3-43            [16, 512, 1, 1]           1,024
│    │    └─Sequential: 3-44             [16, 512, 1, 1]           132,096
│    │    └─ReLU: 3-45                   [16, 512, 1, 1]           --
│    └─BasicBlock: 2-8                   [16, 512, 1, 1]           --
│    │    └─Conv2d: 3-46                 [16, 512, 1, 1]           2,359,296
│    │    └─BatchNorm2d: 3-47            [16, 512, 1, 1]           1,024
│    │    └─ReLU: 3-48                   [16, 512, 1, 1]           --
│    │    └─Conv2d: 3-49                 [16, 512, 1, 1]           2,359,296
│    │    └─BatchNorm2d: 3-50            [16, 512, 1, 1]           1,024
│    │    └─ReLU: 3-51                   [16, 512, 1, 1]           --
├─AdaptiveAvgPool2d: 1-9                 [16, 512, 1, 1]           --
├─Linear: 1-10                           [16, 1000]                513,000
=================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (M): 257.07
=================================================
Input size (MB): 0.05
Forward/backward pass size (MB): 3.86
Params size (MB): 46.76
Estimated Total Size (MB): 50.67
=================================================

各レイヤーごとのパラメータ数、出力形状、Forward/backward時の計算容量なども表示してくれます。

以前は、Tensorflow/Kerasのmodel.summary()のように表示させるには、torchsummaryライブラリが使われていたようですが、こちらはもうアップデートされていないようで、torchinfoライブラリに吸収されたようです。

関連資料、関連記事

  • torchinfo
GitHub - TylerYep/torchinfo: View model summaries in PyTorch!
View model summaries in PyTorch! Contribute to TylerYep/torchinfo development by creating an account on GitHub.
  • torchsummary
GitHub - sksq96/pytorch-summary: Model summary in PyTorch similar to `model.summary()` in Keras
Model summary in PyTorch similar to `model.summary()` in Keras - sksq96/pytorch-summary

コメント