FCNによる画像セマンティックセグメンテーション:実践的アプローチ

学習のポイント

学習には忍耐が必要です。コードの実行には少なくとも1時間かかります。モデルが自動で実行中は、他の作業に専念できますが、最終結果を見るには忍耐と心構えが求められます。

本日は実際の応用として画像セグメンテーションを学びます。セグメンテーションツールの一つであるFCNの使用方法を紹介し、操作実験を行います。実験手順に従えば、期待される画像処理結果が得られます。学習後には、このツールの実際の応用を習得できます。

FCNによる画像セマンティックセグメンテーション

全畳み込みネットワーク(Fully Convolutional Networks、FCN)は、UC BerkeleyのJonathan Longらが2015年の論文「Fully Convolutional Networks for Semantic Segmentation」で提案した画像セマンティックセグメンテーション用のフレームワークです。

FCNは、ピクセルレベルの予測をエンドツーエンドで行う最初の全畳み込みネットワークです。

セマンティックセグメンテーション

FCNの具体的な紹介に先立ち、セマンティックセグメンテーションとは何かを説明します:

画像セマンティックセグメンテーションは、画像処理とコンピュータビジョンにおける画像理解の重要な一部です。AI分野における重要なブランチで、顔認識、物体検出、医学画像、衛星画像分析、自動運転の認知などに広く応用されています。

セマンティックセグメンテーションの目的は、画像内の各ピクセルを分類することです。通常の分類タスクが特定のカテゴリのみを出力するのに対し、セマンティックセグメンテーションタスクは、入力と同じサイズの出力画像を生成し、出力画像の各ピクセルが入力画像の対応するピクセルのカテゴリに対応します。

モデル概要

FCNは主にセグメンテーション分野に使用され、エンドツーエンドのセグメンテーション手法です。深層学習を画像セマンティックセグメンテーションに応用した最初の試みです。ピクセルレベルの予測を直接行い、元の画像サイズと同じラベルマップを出力します。FCNは全結合層を全畳み込み層に置き換えたため、すべての層が畳み込み層で構成され、全畳み込みネットワークと呼ばれます。

全畳み込みニューラルネットワークは主に以下の3つの技術を使用します:

  1. 畳み込み化(Convolutional)

    VGG-16をFCNのバックボーンとして使用します。VGG-16の入力は224×224のRGB画像で、出力は1000個の予測値です。VGG-16は固定サイズの入力のみを受け付け、空間座標を破棄し、非空間出力を生成します。VGG-16には3つの全結合層があり、全結合層は領域全体をカバーする畳み込みと見なすこともできます。全結合層を畳み込み層に変換することで、一次元の非空間出力を二次元行列に変換し、出力から入力画像のマッピングヒートマップを生成できます。

  2. アップサンプリング(Upsample)

    畳み込みプロセスの畳み込み操作とプーリング操作は特徴マップのサイズを小さくします。元の画像サイズの密な画像予測を得るためには、得られた特徴マップをアップサンプリングする必要があります。双線形補間のパラメータを使用してアップサンプリング逆畳み込みのパラメータを初期化し、逆伝播で非線形アップサンプリングを学習します。ネットワーク内でアップサンプリングを実行し、ピクセル損失の逆伝播を通じてエンドツーエンドの学習を行います。

  3. スキップ構造(Skip Layer)

    アップサンプリングのテクニックを使用して最後の層の特徴マップをアップサンプリングし、元のサイズのセグメンテーションを取得することは、ステップサイズが32ピクセルの予測と呼ばれます。最後の層の特徴マップが小さすぎて多くの詳細を失うため、スキップ構造により、よりグローバルな情報を持つ最後の層の予測と、より浅い層の予測を結合し、予測結果に多くの局所的な詳細を取得させます。下位層(ステップ32)の予測(FCN-32s)を2倍にアップサンプリングして元サイズの画像にし、pool4層(ステップ16)からの予測と融合(加算)します。この部分のネットワークはFCN-16sと呼ばれます。次に、この部分の予測をさらに2倍にアップサンプリングし、pool3層からの予測と融合させます。この部分のネットワークはFCN-8sと呼ばれます。スキップ構造により、深層のグローバル情報と浅層の局所情報が結合されます。

ネットワークの特徴

  1. 全結合層(fc)を含まない全畳み込み(fully conv)ネットワークで、任意サイズの入力に対応できます。
  2. データサイズを増やす逆畳み込み(deconv)層により、精細な結果を出力できます。
  3. 異なる深さの層の結果を結合するスキップ構造により、堅牢性と精度を両立しています。

データ処理

実験を開始する前に、Python環境とMindSporeがローカルにインストールされていることを確認してください。

データセットのダウンロード
%%capture captured_output

# 実験環境にはすでにmindspore==2.2.14がインストールされています。他のバージョンに変更する場合は、以下のバージョン番号を変更してください
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
データセットの確認
# 現在のmindsporeバージョンを確認
!pip show mindspore
データセットのダウンロードと展開
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"
download(url, "./dataset", kind="tar", replace=True)

データ前処理

PASCAL VOC 2012データセット内の画像の解像度がほとんど一致していないため、テンソル内に配置できません。そのため、入力前に標準化処理が必要です。

データの読み込み

PASCAL VOC 2012データセットとSDBデータセットを混合します。

import numpy as np
import cv2
import mindspore.dataset as ds

class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):

        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        assert max_scale > min_scale

    def preprocess_dataset(self, image, label):
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
        
        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
            
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size]
        
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
            
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset

# データセット作成のパラメータ定義
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# モデル学習パラメータ定義
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# Datasetのインスタンス化
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)

dataset = dataset.get_dataset()

訓練データの可視化

以下のコードを実行して、読み込まれたデータセットの画像を確認します(データ処理の過程で既に正規化処理が行われています)。

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))

# 訓練データセットのデータを表示
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    # 画像をHWC形式に変換して表示
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

ネットワーク構築

ネットワークの流れ

FCNネットワークの流れは以下の通りです:

  1. 入力画像imageをpool1でプーリングすると、サイズが元のサイズの1/2になります。
  2. pool2でプーリングすると、サイズが元のサイズの1/4になります。
  3. 次にpool3、pool4、pool5でプーリングし、サイズはそれぞれ元のサイズの1/8、1/16、1/32になります。
  4. conv6-7の畳み込み後の出力サイズは依然として元の画像の1/32です。
  5. FCN-32sは最後に逆畳み込みを使用して、出力画像のサイズを入力画像と同じにします。
  6. FCN-16sはconv7の出力を逆畳み込みしてサイズを2倍に拡大し(元のサイズの1/16)、pool4の出力特徴マップと融合させた後、逆畳み込みで元のサイズに拡大します。
  7. FCN-8sはconv7の出力を4倍に逆畳み拡大し、pool4の出力特徴マップを2倍に逆畳み拡大し、pool3の出力特徴マップを取り出し、3つを融合させてから逆畳み込みで元のサイズに拡大します。

以下のコードを使用してFCN-8sネットワークを構築します。

import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        
        # 第一畳み込みブロック
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第二畳み込みブロック
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第三畳み込みブロック
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第四畳み込みブロック
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第五畳み込みブロック
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # 第六畳み込みブロック
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096, kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        
        # 第七畳み込みブロック
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        
        # スコア層とアップサンプリング層
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, 
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, 
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class, kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class, 
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        # エンコーダ部分
        x1 = self.conv1(x)
        p1 = self.pool1(x1)
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        
        # デコーダ部分
        sf = self.score_fr(x7)
        u2 = self.upscore2(sf)
        s4 = self.score_pool4(p4)
        f4 = s4 + u2
        u4 = self.upscore_pool4(f4)
        s3 = self.score_pool3(p3)
        f3 = s3 + u4
        out = self.upscore8(f3)
        return out

学習の準備

VGG-16の事前学習重みのインポート

FCNは画像エンコーディングを実現するためにVGG-16をバックボーンとして使用します。以下のコードを使用してVGG-16事前学習モデルの一部の事前学習重みをインポートします。

from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)

def load_vgg16():
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)

損失関数

セマンティックセグメンテーションは画像内の各ピクセルを分類するため、依然として分類問題です。そのため、損失関数としてFCNネットワークの出力とマスク間の交差エントロピー損失を計算する交差エントロピー損失関数を選択します。ここではmindspore.nn.CrossEntropyLoss()を損失関数として使用します。

カスタム評価指標 Metrics

この部分では、学習したモデルの効果を評価します。説明を容易にするために、以下のように仮定します:合計𝑘+1個のクラス(𝐿0から𝐿𝑘まで、背景を含む)があり、𝑝𝑖𝑗は本来𝑖クラスに属するが𝑗クラスと予測されたピクセル数を示します。つまり、𝑝𝑖𝑖は真の数を表し、𝑝𝑖𝑗と𝑝𝑗𝑖はそれぞれ偽陽性と偽陰性と解釈されますが、両者は偽陽性と偽陰性の和です。

  • ピクセル精度(Pixel Accuracy, PA):最も単純な指標で、正しく分類されたピクセルの総ピクセルに対する比率です。

𝑃𝐴=∑𝑘𝑖=0𝑝𝑖𝑖∑𝑘𝑖=0∑𝑘𝑗=0𝑝𝑖𝑗

  • 平均ピクセル精度(Mean Pixel Accuracy, MPA):PAの単純な改良版で、各クラス内で正しく分類されたピクセル数の比率を計算し、その後すべてのクラスの平均を取ります。

𝑀𝑃𝐴=1𝑘+1∑𝑖=0𝑘𝑝𝑖𝑖∑𝑘𝑗=0𝑝𝑖𝑗

  • 平均交差比(Mean Intersection over Union, MloU):セマンティックセグメンテーションの標準的な指標です。2つの集合の交差と和の比を計算します。セマンティックセグメンテーションの問題では、これらの集合は真値(ground truth)と予測値(predicted segmentation)です。この比率は、真陽性(intersection)を真陽性、偽陰性、偽陽性(和)の和で割ったものに変形できます。各クラスでloUを計算し、その平均を取ります。

𝑀𝐼𝑜𝑈=1𝑘+1∑𝑖=0𝑘𝑝𝑖𝑖∑𝑘𝑗=0𝑝𝑖𝑗+∑𝑘𝑗=0𝑝𝑗𝑖−𝑝𝑖𝑖

  • 頻度重み付き交差比(Frequency Weighted Intersection over Union, FWIoU):MloUの改良版で、各クラスの出現頻度に基づいて重みを設定する方法です。

𝐹𝑊𝐼𝑜𝑈=1∑𝑘𝑖=0∑𝑘𝑗=0𝑝𝑖𝑗∑𝑖=0𝑘𝑝𝑖𝑖∑𝑘𝑗=0𝑝𝑖𝑗+∑𝑘𝑗=0𝑝𝑗𝑖−𝑝𝑖𝑖

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train

class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy

class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy

class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou

class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou

モデルの学習

VGG-16事前学習パラメータをインポートした後、損失関数と最適化器のインスタンスを作成し、Modelインターフェースを使用してネットワークをコンパイルし、FCN-8sネットワークを学習させます。

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model

device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21

# モデル構造の初期化
net = FCN8s(n_class=21)

# vgg16事前学習パラメータのインポート
load_vgg16()

# 学習率の計算
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
lr = Tensor(lr_scheduler[-1])

# 損失関数の定義
loss = nn.CrossEntropyLoss(ignore_index=255)

# 最適化器の定義
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)

# loss_scaleの定義
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)

# モデルの初期化
if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, 
                  metrics={"pixel accuracy": PixelAccuracy(), 
                          "mean pixel accuracy": PixelAccuracyClass(), 
                          "mean IoU": MeanIntersectionOverUnion(), 
                          "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, 
                  metrics={"pixel accuracy": PixelAccuracy(), 
                          "mean pixel accuracy": PixelAccuracyClass(), 
                          "mean IoU": MeanIntersectionOverUnion(), 
                          "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# ckptファイル保存のパラメータ設定
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]

save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                               directory="./ckpt",
                               config=config_ckpt)
callbacks.append(ckpt_callback)

model.train(train_epochs, dataset, callbacks=callbacks)

モデルの評価

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# 学習済みの重みファイルをダウンロード
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)

net = FCN8s(n_class=num_classes)

ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

if device_target == "Ascend":
    model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, 
                  metrics={"pixel accuracy": PixelAccuracy(), 
                          "mean pixel accuracy": PixelAccuracyClass(), 
                          "mean IoU": MeanIntersectionOverUnion(), 
                          "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net, loss_fn=loss, optimizer=optimizer, 
                  metrics={"pixel accuracy": PixelAccuracy(), 
                          "mean pixel accuracy": PixelAccuracyClass(), 
                          "mean IoU": MeanIntersectionOverUnion(), 
                          "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Datasetのインスタンス化
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()

model.eval(dataset_eval)

モデルの推論

学習したネットワークを使用して、モデルの推論結果を表示します。

import cv2
import matplotlib.pyplot as plt

net = FCN8s(n_class=num_classes)

# ハイパーパラメータの設定
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []

# 推論結果の表示(上:入力画像、下:推論結果画像)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)

for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])

res = net(show_data["data"]).asnumpy().argmax(axis=1)

for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

まとめ

FCNの主な貢献は、全畳み込み層を使用し、画像でエンドツーエンドのセグメンテーションを実現する学習を提案した点にあります。従来のCNNを使用した画像セグメンテーション手法と比較して、FCNには2つの明確な利点があります:

  1. 任意サイズの入力画像を受け付け、すべての学習画像とテスト画像に固定サイズを要求する必要がありません。
  2. より効率的で、ピクセルブロックを使用することによる重複したストレージと畳み込み計算の問題を回避します。

同時に、FCNネットワークにも改善すべき点があります:

  1. 得られた結果が十分に精細ではありません。8倍のアップサンプリングは32倍の効果はるかに良いですが、アップサンプリングの結果は依然としてぼやけており、特に境界部分で平滑です。ネットワークは画像の詳細に敏感ではありません。
  2. 各ピクセルを分類しますが、ピクセル間の関係(不連続性と類似性)を十分に考慮していません。通常のピクセルベースの分離セグメンテーション手法で使用される空間的正則化(spatial regularization)ステップを無視しており、空間的一貫性に欠けています。

参考文献

[1] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. "Fully convolutional networks for Semantic Segmentation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2015.

タグ: FCN セマンティックセグメンテーション ディープラーニング 画像処理 MindSpore

5月11日 14:57 投稿