スポンサーリンク

【Pytorch】Segmentation Models Pytorchの基本的な使い方

Machine Learning

Pytorchで、Segmentationのモデルを簡単に使える「Segmentation Models Pytorch」の使い方を見ていく。Kaggleでもよく使われている。

Segmentation Models Pytorchの使い方

インストール

PyPIでインストールできる。

pip install -U segmentation-models-pytorch

ソースから最新バージョンをインストールしたい場合は、下記を実行する。

pip install -U git+https://github.com/qubvel/segmentation_models.pytorch

モデルの作成

モデルの作成は簡単で、ライブラリのインポートをすれば下記コードでできる。

model = smp.Unet(
    encoder_name="resnet34",        
    encoder_weights="imagenet",     
    in_channels=1,                  
    classes=3,                      
)

smp.XXXに扱うモデル名、encoder_nameに特徴量を抽出するためのエンコーダー(バックボーン)、encoder_weightsに事前学習済みの重み(使わない場合は’None’)、in_channelsに入力のチャンネル数、classesに出力マスクのクラス数を指定する。他にも出力にアクティベーション関数を適用するactivation、デコーダーにアテンションモジュールを適用するdecoder_attention_typeなどがある。
選択できるモデルは下記で、モデルごとに異なるオプションが存在するため、リンク先を参照して設定する。
Unet
Unet++
MAnet
Linknet
FPN
PSPNet
PAN
DeepLabV3
DeepLabV3+

バックボーンで使用できるモデル

encoder_nameencoder_weightsで指定できるモデルと事前学習済みの重みは、下記リンク先から選択できる。

ロス、メトリックス

セグメーテンションの基本的なlossやMetricsも揃っており便利そう。

参考資料、おすすめ書籍

コメント