Pytorchで、Segmentationのモデルを簡単に使える「Segmentation Models Pytorch」の使い方を見ていく。Kaggleでもよく使われている。
Segmentation Models Pytorchの使い方
インストール
PyPIでインストールできる。
pip install -U segmentation-models-pytorch
ソースから最新バージョンをインストールしたい場合は、下記を実行する。
pip install -U git+https://github.com/qubvel/segmentation_models.pytorch
モデルの作成
モデルの作成は簡単で、ライブラリのインポートをすれば下記コードでできる。
model = smp.Unet(
encoder_name="resnet34",
encoder_weights="imagenet",
in_channels=1,
classes=3,
)
smp.XXXに扱うモデル名、encoder_nameに特徴量を抽出するためのエンコーダー(バックボーン)、encoder_weightsに事前学習済みの重み(使わない場合は’None’)、in_channelsに入力のチャンネル数、classesに出力マスクのクラス数を指定する。他にも出力にアクティベーション関数を適用するactivation、デコーダーにアテンションモジュールを適用するdecoder_attention_typeなどがある。
選択できるモデルは下記で、モデルごとに異なるオプションが存在するため、リンク先を参照して設定する。
— Unet
— Unet++
— MAnet
— Linknet
— FPN
— PSPNet
— PAN
— DeepLabV3
— DeepLabV3+
バックボーンで使用できるモデル
encoder_nameとencoder_weightsで指定できるモデルと事前学習済みの重みは、下記リンク先から選択できる。
ロス、メトリックス
セグメーテンションの基本的なlossやMetricsも揃っており便利そう。
参考資料、おすすめ書籍
- PyTorch実践入門 ~ ディープラーニングの基礎から実装へ
PyTorchの公式でも紹介されていた書籍で基本的な内容から応用まで学べ、第13章では医療画像を用いたセグメンテーションが対象となっている。 - つくりながら学ぶ! PyTorchによる発展ディープラーニング
ディープラーニングの発展・応用手法を実装しながら学習していく書籍で、第3章でセマンティックセグメンテーション(PSPNet):ピクセルレベルで画像内の物体を検出がある。
コメント