3D-CNNを用いたCOVID-19胸部CTスキャンデータセットの二値分類実装

1、はじめに

本稿では、3次元畳み込みニューラルネットワーク(3D CNN)を構築し、電子コンピュータ断層撮影(CT)スキャンが新型コロナウイルス肺炎に感染しているかどうかを予測する方法を紹介します。

通常、2次元CNNはRGB画像(3チャネル)の処理に使用されます。

3D CNN:このアーキテクチャは3次元データまたは2次元フレームシーケンス(CTスキャン内のスライスなど)を入力として受け取り、3次元深度または連続するビデオフレームから多チャネル情報を生成し、各チャネルで畳み込みとダウンサンプリング操作を個別に行います。最後にすべてのチャネル情報を組み合わせて最終的な特徴記述を得ます。

2、関連ライブラリのインポート

In [ ]

!pip install nibabel -t /home/aistudio/external-libraries
!pip install scipy -t /home/aistudio/external-libraries

本例はPaddle 2.3.2を基に作成されています。環境がこのバージョンでない場合は、まず公式サイトでPaddle 2.3.2をインストールしてください。

In [38]

import os
import zipfile
import numpy as np
import paddle
import paddle.nn as nn
from paddle.io import Dataset
import random
import matplotlib.pyplot as plt
import sys 
sys.path.append('/home/aistudio/external-libraries')
from scipy import ndimage
import nibabel as nib

paddle.__version__
'2.3.2'

3、データセットの解凍

完全なデータセットリンク:https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1

In [ ]

# データを保存するディレクトリを作成します。
os.makedirs("MosMedData")

# 新しく作成したディレクトリにデータを解凍します。
with zipfile.ZipFile("data/data112472/ct-0.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

with zipfile.ZipFile("data/data112472/ct-23.zip", "r") as z_fp:
    z_fp.extractall("./MosMedData/")

4、データの読み込みと前処理

CTスキャンの関連放射線学的所見をラベルとして使用し、新型コロナウイルス肺炎の存在を予測する分類器を構築します。したがって、このタスクは二値分類問題です。

これらのファイルはNIFTI形式で提供され、拡張子は.niiです。スキャンを読み取るために、nibabelパッケージを使用します。このパッケージはpip install nibabelでインストールできます。CTスキャンはHounsfield単位(HU)で生のボクセル強度が保存されています。このデータセットでは、その範囲は-1024から2000以上です。400以上の骨は異なる放射線強度を持っているため、これがより高い境界として使用されます。通常、-1000から400の間の閾値を使用してCTスキャンを標準化します。

データを処理するために、以下の操作を実行します:

  • まず体積を90度回転させ、方向を固定します。
  • HU値を0から1の間にスケーリングします。
  • 幅、高さ、深さのサイズを調整します。

ここでは、データを処理するためのいくつかのヘルパー関数を定義します。これらの機能は、トレーニングおよび検証データセットの構築に使用されます。

In [5]

def read_nifti_file(filepath):
    """ファイルの読み込みとデータの取得"""
    # ファイルの読み込み
    scan = nib.load(filepath)
    # 生データの取得
    scan = scan.get_fdata()
    return scan

def normalize(volume):
    """データの正規化"""
    min_val = -1000
    max_val = 400
    volume[volume < min_val] = min_val
    volume[volume > max_val] = max_val
    volume = (volume - min_val) / (max_val - min_val)
    volume = volume.astype("float32")
    return volume

def resize_volume(img):
    """z軸に沿ってリサイズ"""
    # 必要な深さを設定
    desired_depth = 64
    desired_width = 128
    desired_height = 128
    # 現在の深さを取得
    current_depth = img.shape[-1]
    current_width = img.shape[0]
    current_height = img.shape[1]
    # 深さ係数を計算
    depth = current_depth / desired_depth
    width = current_width / desired_width
    height = current_height / desired_height
    depth_factor = 1 / depth
    width_factor = 1 / width
    height_factor = 1 / height
    # 回転
    img = ndimage.rotate(img, 90, reshape=False)
    # z軸に沿ってリサイズ
    img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
    return img

def process_scan(path):
    """スキャンの読み込みとリサイズ"""
    # スキャンファイルの読み込み
    volume = read_nifti_file(path)
    # 正規化
    volume = normalize(volume)
    # 幅、高さ、深さの調整
    volume = resize_volume(volume)
    return volume

クラスディレクトリからCTスキャンのパスを読み込みます。

In [6]

# フォルダ「CT-0」には正常な肺組織を持つCTスキャンが含まれています。
# ウイルス性肺炎のないCTです。
normal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-0", x)
    for x in os.listdir("MosMedData/CT-0")
]

# フォルダ「CT-23」には感染したCTスキャンが含まれています。
abnormal_scan_paths = [
    os.path.join(os.getcwd(), "MosMedData/CT-23", x)
    for x in os.listdir("MosMedData/CT-23")
]

print("正常な肺組織を持つCTスキャン: " + str(len(normal_scan_paths)))
print("異常な肺組織を持つCTスキャン: " + str(len(abnormal_scan_paths)))
正常な肺組織を持つCTスキャン: 100
異常な肺組織を持つCTスキャン: 100

5、トレーニングと検証データセットの構築

データセットディレクトリからスキャンを読み込み、ラベルを割り当てます。

スキャン結果をダウンサンプリングし、形状を128x128x64にします。

元のHU値を0から1の範囲に再スケーリングします。最後に、データセットをトレーニングサブセットと検証サブセットに分割します。

In [7]

abnormal_path = "abnormal_scans.npy"
normal_path = "normal_scans.npy"
# この変換は時間がかかるため、実行を中断すると再度変換が必要になります。
# そのため、変換結果をキャッシュします。
if os.path.exists(abnormal_path) and os.path.exists(normal_path):
    abnormal_scans = np.load(abnormal_path)
    normal_scans = np.load(normal_path)
else:
    abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
    normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
    np.save(abnormal_path, abnormal_scans)
    np.save(normal_path, normal_scans)

CTスキャンは、トレーニング中にランダムな角度で回転させることでデータ拡張も行われます。データがrank-3テンソルの形状で保存されているため、3D畳み込みを実行できるように、軸4にサイズ1の次元を追加します。

In [8]

# 肺スキャンデータセット
class MedicalDataset(Dataset):
    """
        ステップ1:paddle.io.Datasetクラスを継承
    """
    def __init__(self, abnormal_scans, normal_scans, mode='train'):
        """
        ステップ2:コンストラクタを実装
        Args:
            abnormal_scans: 異常なCTスキャン
            normal_scans: 正常なCTスキャン
            mode: train、val
        """
        super().__init__()
        self.images = []
        self.labels = []
        self.is_training = False
        abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
        normal_labels = np.array([0 for _ in range(len(normal_scans))])
        if mode == 'train':
            x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
            y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
            self.images = x_train
            self.labels = y_train
            self.is_training = True
        else:
            x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
            y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
            self.images = x_val
            self.labels = y_val
            
    def __getitem__(self, index):
        """
        ステップ3:__getitem__メソッドを実装し、指定されたインデックスでデータを取得する方法を定義し、単一のデータ(トレーニングデータ/テストデータと対応するラベル)を返します
        """
        img = self.images[index]
        lab = self.labels[index]
        if self.is_training:
            img = self.augment(img)
        img = np.array(img).astype('float32')
        lab = np.array([lab], dtype="int64")
        img = np.expand_dims(img, 3)
        return img, lab

    def __len__(self):
        """
        ステップ4:__len__メソッドを実装し、データセットの総数を返します
        """
        return len(self.labels)

    def augment(self, volume):
        # データ拡張用の回転角度を定義
        angles = [-20, -10, -5, 5, 10, 20]
        # ランダムに角度を選択
        angle = random.choice(angles)
        # CTスキャンの回転
        volume = ndimage.rotate(volume, angle, reshape=False)
        volume[volume < 0] = 0
        volume[volume > 1] = 1
        return volume

In [9]

train_dataset = MedicalDataset(abnormal_scans, normal_scans)
validation_dataset = MedicalDataset(abnormal_scans, normal_scans, mode='val')

train_loader = paddle.io.DataLoader(train_dataset, batch_size=10, shuffle=True)
validation_loader = paddle.io.DataLoader(validation_dataset, batch_size=10, shuffle=False)

拡張されたCTスキャンの可視化。

In [11]

images, labels = train_dataset.__getitem__(0)
print(images.shape)
image = np.squeeze(images, 3)
print("CTスキャンの次元:", image.shape)
plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
plt.show()
(128, 128, 64, 1)
CTスキャンの次元: (128, 128, 64)
<Figure size 640x480 with 1 Axes>

CTスキャンには多くのスライスがあるため、複数のスライスを可視化します。

In [12]

def visualize_slices(num_rows, num_columns, width, height, data):
    """CTの横断面を描画"""
    data = np.rot90(np.array(data))
    data = np.transpose(data)
    data = np.reshape(data, (num_rows, num_columns, width, height))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 6.0
    fig_height = fig_width * sum(heights) / sum(widths)
    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
    )
    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
    plt.show()


# CT横断面の描画
visualize_slices(8, 8, 128, 128, image[:, :, :64])
<Figure size 600x600 with 64 Axes>

6、3D畳み込みニューラルネットワークの定義

畳み込みニューラルネットワーク(CNN)と言うと、通常は画像分類用の2次元CNNを指します。しかし、現実世界では1次元CNNと3次元CNNという他の2種類の畳み込みニューラルネットワークも使用されています。このガイドでは、1Dと3D CNNとその現実世界での応用について紹介します。畳み込みネットワークの概念に一般的に慣れていることを前提とします。初心者は1次元CNNが1次元データを処理し、2次元CNNが2次元データを処理すると誤解しやすいです!!!

畳み込みニューラルネットワーク(CNN)では、1次元および2次元フィルターは実際には1次元および2次元ではありません。これは単なる表現の慣習です。

2次元CNN | Conv2D

Lenet-5アーキテクチャで初めて導入された標準的な畳み込みニューラルネットワークでは、Conv2Dは通常画像データに使用されます。2次元CNNと呼ばれるのは、カーネルが2次元に沿ってデータ上をスライドするためです。

カーネルが画像上をスライド

CNNの全体的な利点は、他のネットワークではできない空間特徴をカーネルを使用してデータから抽出できることです。たとえば、CNNは画像内のエッジ、色の分布などを検出でき、これによりこれらのネットワークは画像分類や空間属性を含む他の類似データで非常に強力になります。

1次元CNN | Conv1D

Conv1Dを紹介する前に、ヒントを提供します。Conv1Dでは、カーネルは1次元に沿ってスライドします。ここで一時停止して、カーネルが1次元のみでスライドし、空間特性を持つデータタイプを考えてみましょう? 答えは時系列データです。以下のデータを見てみましょう。

加速度計からの系列データ

このデータは、腕に取り付けられた加速度計から収集されたものです。データは3軸すべての加速度を表します。1次元CNNは、加速度計データに基づいて活動認識タスクを実行できます。たとえば、人の姿勢、歩行、ジャンプなどです。このデータには2つの次元があります。1次元は時間ステップで、他の次元は3軸の加速度値です。下の図は、カーネルが加速度計データ上を移動する方法を示しています。各行は特定の軸の時間系列加速度を表します。カーネルは時間軸に沿ってのみ1次元で移動できます。

Conv1Dはセンサーデータ、特に加速度計データで広く使用されています。

3次元CNN | Conv3D

Conv3Dでは、カーネルは3次元に沿ってスライドします。では、どのタイプのデータがカーネルが3次元で移動する必要があるでしょうか?

カーネルが3Dデータ上をスライド

Conv3Dは主に3D画像データで使用されます。たとえば、磁気共鳴画像(MRI)データです。MRIデータは、脳、脊髄、内部器官などを検査するために広く使用されています。このケースでは、コンピュータ断層撮影(CT)スキャンも3次元データであり、これは身体の周囲から異なる角度で撮影された一連のX線画像を組み合わせて作成されます。Conv3Dを使用して、この医学データを分類したり特徴を抽出したりできます。

Layerに基づいてモデルネットワーク構造を定義し、モデルの可視化を表示します。

In [13]

# 3D CNNネットワーク
class Medical3DCNN(nn.Layer):
    def __init__(self):
        super(Medical3DCNN, self).__init__()     
        self.network = paddle.nn.Sequential(
            paddle.nn.Conv3D(in_channels=1, out_channels=64, kernel_size=3),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(64),

            paddle.nn.Conv3D(in_channels=64, out_channels=64, kernel_size=3),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(64),

            paddle.nn.Conv3D(in_channels=64, out_channels=128, kernel_size=3),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(128),

            paddle.nn.Conv3D(in_channels=128, out_channels=256, kernel_size=3),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool3D(kernel_size=2),
            paddle.nn.BatchNorm3D(256),
            
            paddle.nn.AdaptiveAvgPool3D(output_size=1),
            paddle.nn.Flatten(),
            paddle.nn.Linear(256, 512),
            paddle.nn.Dropout(p=0.3),

            paddle.nn.Linear(512, 1),
            paddle.nn.Sigmoid()
        )

    def forward(self, inputs):
        # データ入力は【バッチ、幅、高さ、深さ、チャネル】なので、Paddleのデフォルト入力要件【バッチ、チャネル、深さ、幅、高さ】に調整
        inputs = inputs.transpose((0, 4, 3, 1, 2))
        x = self.network(inputs)
        return x

In [30]

model = Medical3DCNN()
paddle.summary(model, (1, 128, 128, 64, 1))
-----------------------------------------------------------------------------------
   Layer (type)           Input Shape           Output Shape          Param #    
===================================================================================
     Conv3D-9       [[1, 1, 64, 128, 128]]  [1, 64, 62, 126, 126]      1,792     
      ReLU-9        [[1, 64, 62, 126, 126]] [1, 64, 62, 126, 126]        0       
    MaxPool3D-9     [[1, 64, 62, 126, 126]]  [1, 64, 31, 63, 63]         0       
   BatchNorm3D-9     [[1, 64, 31, 63, 63]]   [1, 64, 31, 63, 63]        256      
     Conv3D-10       [[1, 64, 31, 63, 63]]   [1, 64, 29, 61, 61]      110,656    
      ReLU-10        [[1, 64, 29, 61, 61]]   [1, 64, 29, 61, 61]         0       
   MaxPool3D-10      [[1, 64, 29, 61, 61]]   [1, 64, 14, 30, 30]         0       
  BatchNorm3D-10     [[1, 64, 14, 30, 30]]   [1, 64, 14, 30, 30]        256      
     Conv3D-11       [[1, 64, 14, 30, 30]]  [1, 128, 12, 28, 28]      221,312    
      ReLU-11       [[1, 128, 12, 28, 28]]  [1, 128, 12, 28, 28]         0       
   MaxPool3D-11     [[1, 128, 12, 28, 28]]   [1, 128, 6, 14, 14]         0       
  BatchNorm3D-11     [[1, 128, 6, 14, 14]]   [1, 128, 6, 14, 14]        512      
     Conv3D-12       [[1, 128, 6, 14, 14]]   [1, 256, 4, 12, 12]      884,992    
      ReLU-12        [[1, 256, 4, 12, 12]]   [1, 256, 4, 12, 12]         0       
   MaxPool3D-12      [[1, 256, 4, 12, 12]]    [1, 256, 2, 6, 6]          0       
  BatchNorm3D-12      [[1, 256, 2, 6, 6]]     [1, 256, 2, 6, 6]        1,024     
AdaptiveAvgPool3D-3   [[1, 256, 2, 6, 6]]     [1, 256, 1, 1, 1]          0       
     Flatten-3        [[1, 256, 1, 1, 1]]         [1, 256]               0       
     Linear-5             [[1, 256]]              [1, 512]            131,584    
     Dropout-3            [[1, 512]]              [1, 512]               0       
     Linear-6             [[1, 512]]               [1, 1]               513      
     Sigmoid-3             [[1, 1]]                [1, 1]                0       
===================================================================================
Total params: 1,352,897
Trainable params: 1,350,849
Non-trainable params: 2,048
-----------------------------------------------------------------------------------
Input size (MB): 4.00
Forward/backward pass size (MB): 1222.30
Params size (MB): 5.16
Estimated Total Size (MB): 1231.46
-----------------------------------------------------------------------------------

{'total_params': 1352897, 'trainable_params': 1350849}

7、モデルのトレーニング

前述の3D CNNとデータセットを使用してモデルをトレーニングします。

これは二値分類タスクであるため、ネットワークの出力は0-1の間の浮動小数点数です。したがって、ここではBCE損失を使用します。

In [31]

def calculate_accuracy(predictions, labels):
    pred = np.squeeze(predictions.cpu().numpy(), -1)
    pred = pred < 0.5
    true = np.squeeze(labels.cpu().numpy(), -1) == 0
    correct = np.sum(pred == true)
    return correct / len(pred)

training_losses = []
training_accuracies = []
validation_losses = []
validation_accuracies = []

def train_model():
    scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.0001, gamma=0.96, verbose=True)
    optimizer = paddle.optimizer.AdamW(learning_rate=scheduler, parameters=model.parameters())
    criterion = paddle.nn.BCELoss()
    epochs = 25
    best_accuracy = 0
    save_path = "best_medical_model"
    
    for epoch in range(epochs):
        
        total_train, train_acc, train_loss = 0, 0, 0
        model.train()
        for i, (img, labels) in enumerate(train_loader):     
            predictions = model(img)
            loss = criterion(predictions, labels.astype(np.float32))
            optimizer.clear_grad()
            loss.backward()
            optimizer.minimize(loss)
            acc = calculate_accuracy(predictions, labels)
            train_acc += acc
            train_loss += loss.numpy()
            if i % 10 == 0:
                print(f"Epoch {epoch} Iter {i} Loss: {loss.numpy()}, Accuracy: {acc}")
            total_train += 1
        total_val, val_acc, val_loss = 0, 0, 0
        model.eval()
        for i, (img, labels) in enumerate(validation_loader):
            with paddle.no_grad():
                predictions = model(img)
                loss = criterion(predictions, labels.astype(np.float32))
            val_acc += calculate_accuracy(predictions, labels)
            val_loss += loss.numpy()
            total_val += 1
        print(f"Epoch {epoch} Validation Loss: {val_loss/total_val} Accuracy: {val_acc/total_val}")
        # 最良のモデルのみ保存
        if val_acc/total_val > best_accuracy:
            best_accuracy = val_acc/total_val
            model_state = model.state_dict()
            paddle.save(model_state, save_path)
            print(f"Best accuracy: {best_accuracy}")
        scheduler.step()
        training_losses.append(train_loss/total_train)
        training_accuracies.append(train_acc/total_train)
        validation_losses.append(val_loss/total_val)
        validation_accuracies.append(val_acc/total_val)

train_model()
Epoch 0: ExponentialDecay set learning rate to 0.0001.
Epoch 0 Iter 0 Loss: [0.6301036], Accuracy: 0.5
Epoch 0 Iter 10 Loss: [0.7426481], Accuracy: 0.4
Epoch 0 Validation Loss: [1.1324016] Accuracy: 0.5
Best accuracy: 0.5
Epoch 1: ExponentialDecay set learning rate to 9.6e-05.
Epoch 1 Iter 0 Loss: [0.7325362], Accuracy: 0.5
Epoch 1 Iter 10 Loss: [0.5694233], Accuracy: 0.6
Epoch 1 Validation Loss: [1.0659176] Accuracy: 0.5
Epoch 2: ExponentialDecay set learning rate to 9.216e-05.
Epoch 2 Iter 0 Loss: [0.5374714], Accuracy: 0.6
Epoch 2 Iter 10 Loss: [0.40864113], Accuracy: 0.8
Epoch 2 Validation Loss: [0.7484214] Accuracy: 0.5166666666666667
Best accuracy: 0.5166666666666667
Epoch 3: ExponentialDecay set learning rate to 8.847359999999999e-05.
Epoch 3 Iter 0 Loss: [0.4230544], Accuracy: 0.8
Epoch 3 Iter 10 Loss: [0.4823562], Accuracy: 0.9
Epoch 3 Validation Loss: [0.6880954] Accuracy: 0.6166666666666667
Best accuracy: 0.6166666666666667
Epoch 4: ExponentialDecay set learning rate to 8.493465599999999e-05.
Epoch 4 Iter 0 Loss: [0.6360772], Accuracy: 0.6
Epoch 4 Iter 10 Loss: [0.6532255], Accuracy: 0.6
Epoch 4 Validation Loss: [0.5975937] Accuracy: 0.7000000000000001
Best accuracy: 0.7000000000000001
Epoch 5: ExponentialDecay set learning rate to 8.153726975999998e-05.
Epoch 5 Iter 0 Loss: [0.676396], Accuracy: 0.6
Epoch 5 Iter 10 Loss: [0.5197036], Accuracy: 0.7
Epoch 5 Validation Loss: [0.5965887] Accuracy: 0.7000000000000001
Epoch 6: ExponentialDecay set learning rate to 7.827577896959998e-05.
Epoch 6 Iter 0 Loss: [0.5222303], Accuracy: 0.7
Epoch 6 Iter 10 Loss: [0.28494963], Accuracy: 0.9
Epoch 6 Validation Loss: [0.62934595] Accuracy: 0.6833333333333332
Epoch 7: ExponentialDecay set learning rate to 7.514474781081598e-05.
Epoch 7 Iter 0 Loss: [0.5931848], Accuracy: 0.7
Epoch 7 Iter 10 Loss: [0.34113288], Accuracy: 0.9
Epoch 7 Validation Loss: [0.5483835] Accuracy: 0.7000000000000001
Epoch 8: ExponentialDecay set learning rate to 7.213895789838334e-05.
Epoch 8 Iter 0 Loss: [0.4956868], Accuracy: 0.8
Epoch 8 Iter 10 Loss: [0.8255708], Accuracy: 0.5
Epoch 8 Validation Loss: [0.67082316] Accuracy: 0.5833333333333334
Epoch 9: ExponentialDecay set learning rate to 6.9253399582448e-05.
Epoch 9 Iter 0 Loss: [0.5112359], Accuracy: 0.8
Epoch 9 Iter 10 Loss: [0.5681261], Accuracy: 0.6
Epoch 9 Validation Loss: [0.6324358] Accuracy: 0.65
Epoch 10: ExponentialDecay set learning rate to 6.648326359915008e-05.
Epoch 10 Iter 0 Loss: [0.4329744], Accuracy: 0.9
Epoch 10 Iter 10 Loss: [0.7594853], Accuracy: 0.5
Epoch 10 Validation Loss: [0.5622995] Accuracy: 0.7000000000000001
Epoch 11: ExponentialDecay set learning rate to 6.382393305518408e-05.
Epoch 11 Iter 0 Loss: [0.3882313], Accuracy: 0.8
Epoch 11 Iter 10 Loss: [0.48908353], Accuracy: 0.8
Epoch 11 Validation Loss: [0.54760677] Accuracy: 0.7666666666666666
Best accuracy: 0.7666666666666666
Epoch 12: ExponentialDecay set learning rate to 6.12709757329767e-05.
Epoch 12 Iter 0 Loss: [0.348113], Accuracy: 0.8
Epoch 12 Iter 10 Loss: [0.41176265], Accuracy: 0.9
Epoch 12 Validation Loss: [0.5595346] Accuracy: 0.7833333333333332
Best accuracy: 0.7833333333333332
Epoch 13: ExponentialDecay set learning rate to 5.882013670365765e-05.
Epoch 13 Iter 0 Loss: [0.3314957], Accuracy: 1.0
Epoch 13 Iter 10 Loss: [0.32749084], Accuracy: 1.0
Epoch 13 Validation Loss: [0.5549964] Accuracy: 0.7333333333333334
Epoch 14: ExponentialDecay set learning rate to 5.6467331235511337e-05.
Epoch 14 Iter 0 Loss: [0.36251408], Accuracy: 0.9
Epoch 14 Iter 10 Loss: [0.23271704], Accuracy: 1.0
Epoch 14 Validation Loss: [0.5716041] Accuracy: 0.7333333333333334
Epoch 15: ExponentialDecay set learning rate to 5.4208637986090884e-05.
Epoch 15 Iter 0 Loss: [0.35667133], Accuracy: 0.9
Epoch 15 Iter 10 Loss: [0.38534072], Accuracy: 0.9
Epoch 15 Validation Loss: [0.55513054] Accuracy: 0.7333333333333334
Epoch 16: ExponentialDecay set learning rate to 5.2040292466647244e-05.
Epoch 16 Iter 0 Loss: [0.37971848], Accuracy: 0.9
Epoch 16 Iter 10 Loss: [0.37076467], Accuracy: 0.8
Epoch 16 Validation Loss: [0.546179] Accuracy: 0.6833333333333332
Epoch 17: ExponentialDecay set learning rate to 4.9958680767981346e-05.
Epoch 17 Iter 0 Loss: [0.4073751], Accuracy: 0.8
Epoch 17 Iter 10 Loss: [0.35747206], Accuracy: 0.8
Epoch 17 Validation Loss: [0.50928646] Accuracy: 0.7666666666666667
Epoch 18: ExponentialDecay set learning rate to 4.7960333537262095e-05.
Epoch 18 Iter 0 Loss: [0.26402354], Accuracy: 1.0
Epoch 18 Iter 10 Loss: [0.3062905], Accuracy: 0.9
Epoch 18 Validation Loss: [0.57567877] Accuracy: 0.7166666666666667
Epoch 19: ExponentialDecay set learning rate to 4.6041920195771606e-05.
Epoch 19 Iter 0 Loss: [0.19812493], Accuracy: 1.0
Epoch 19 Iter 10 Loss: [0.37852132], Accuracy: 0.9
Epoch 19 Validation Loss: [0.55201167] Accuracy: 0.6666666666666666
Epoch 20: ExponentialDecay set learning rate to 4.4200243387940746e-05.
Epoch 20 Iter 0 Loss: [0.2532389], Accuracy: 1.0
Epoch 20 Iter 10 Loss: [0.19800806], Accuracy: 1.0
Epoch 20 Validation Loss: [0.6551569] Accuracy: 0.7000000000000001
Epoch 21: ExponentialDecay set learning rate to 4.243223365242311e-05.
Epoch 21 Iter 0 Loss: [0.26303864], Accuracy: 1.0
Epoch 21 Iter 10 Loss: [0.36150858], Accuracy: 0.8
Epoch 21 Validation Loss: [0.5345089] Accuracy: 0.6833333333333332
Epoch 22: ExponentialDecay set learning rate to 4.073494430632618e-05.
Epoch 22 Iter 0 Loss: [0.18942678], Accuracy: 1.0
Epoch 22 Iter 10 Loss: [0.51950914], Accuracy: 0.7
Epoch 22 Validation Loss: [0.56237996] Accuracy: 0.6333333333333333
Epoch 23: ExponentialDecay set learning rate to 3.910554653407313e-05.
Epoch 23 Iter 0 Loss: [0.19643877], Accuracy: 0.9
Epoch 23 Iter 10 Loss: [0.2246792], Accuracy: 1.0
Epoch 23 Validation Loss: [0.61442226] Accuracy: 0.7000000000000001
Epoch 24: ExponentialDecay set learning rate to 3.754132467271021e-05.
Epoch 24 Iter 0 Loss: [0.30469954], Accuracy: 1.0
Epoch 24 Iter 10 Loss: [0.08427861], Accuracy: 1.0
Epoch 24 Validation Loss: [0.51900846] Accuracy: 0.7333333333333333
Epoch 25: ExponentialDecay set learning rate to 3.60396716858018e-05.

8、モデルパフォーマンスの可視化

ここでは、トレーニングセットと検証セットのモデル精度と損失をプロットします。

検証セットはクラスバランスが取れているため、精度はモデルパフォーマンスを偏りのない方法で表現できます。

In [32]

fig, ax = plt.subplots(2, 1, figsize=(8, 12))
ax = ax.ravel()


ax[0].plot(training_losses)
ax[0].plot(validation_losses)
ax[0].set_title("Model {}".format("Loss"))
ax[0].set_xlabel("epochs")
ax[0].legend(["train", "val"])

ax[1].plot(training_accuracies)
ax[1].plot(validation_accuracies)
ax[1].set_title("Model {}".format("Accuracy"))
ax[1].set_xlabel("epochs")
ax[1].legend(["train", "val"])
<matplotlib.legend.Legend at 0x7efe65dae250>
<Figure size 800x1200 with 2 Axes>

注意すべき点は、サンプル数が非常に少なく200個しかないことです。

そのため、検証セットとトレーニングセットでの予測結果に顕著な差異が見られます。ここでは1000以上のCTスキャンで構成される完全なデータセットを見つけることができます。

完全なデータセットを使用すると、精度は83%に達します。

どちらの場合も、分類パフォーマンスの変動性は6-7%です。

9、モデルの予測

モデルを予測させ、効果を示します。

In [33]

param_state_dict = paddle.load("best_medical_model")
model.set_dict(param_state_dict)
model.eval() # 予測モード
def predict(index):
    image, label = validation_dataset.__getitem__(index)
    image_tensor = paddle.to_tensor(np.expand_dims(image, axis=0), dtype=paddle.float32)
    prediction = model(image_tensor)
    scores = [1 - prediction[0][0], prediction[0][0]]
    class_names = ["normal", "abnormal"]
    for score, name in zip(scores, class_names):
        print(
            "Model prediction: %.2f probability, this CT %s"
            % ((100 * score), name)
        )
    print("True label:", class_names[int(label[0])])
    data = np.rot90(image)
    data = np.transpose(data)
    data = np.reshape(data, (8, 8, 128, 128))
    rows_data, columns_data = data.shape[0], data.shape[1]
    heights = [slc[0].shape[0] for slc in data]
    widths = [slc.shape[1] for slc in data[0]]
    fig_width = 6.0
    fig_height = fig_width * sum(heights) / sum(widths)

    f, axarr = plt.subplots(
        rows_data,
        columns_data,
        figsize=(fig_width, fig_height),
        gridspec_kw={"height_ratios": heights},
    )

    for i in range(rows_data):
        for j in range(columns_data):
            axarr[i, j].imshow(data[i][j], cmap="gray")
            axarr[i, j].axis("off")
    plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

    plt.text(-930, -850, "The true label is:" + class_names[int(label[0])], size=15,
            color="yellow", style="italic", weight="light",
            bbox=dict(facecolor="black", alpha=0.3))  
    plt.text(-930, -790, "This model is %.2f percent confident that CT scan is %s" % ((100 * scores[0]), class_names[0]), size=15,
            color="w", style="italic", weight="light",
            bbox=dict(facecolor="black", alpha=0.3))  
    plt.text(-930, -730, "This model is %.2f percent confident that CT scan is %s" % ((100 * scores[1]), class_names[1]), size=15,
            color="r", style="italic", weight="light",
            bbox=dict(facecolor="black", alpha=0.3)) 
    plt.show()

In [35]

predict(2)
Model prediction: 43.85 probability, this CT normal
Model prediction: 56.15 probability, this CT abnormal
True label: abnormal
<Figure size 600x600 with 64 Axes>

In [37]

predict(11)
Model prediction: 32.01 probability, this CT normal
Model prediction: 67.99 probability, this CT abnormal
True label: abnormal
<Figure size 600x600 with 64 Axes>

In [27]

predict(55)
Model prediction: 79.78 probability, this CT normal
Model prediction: 20.22 probability, this CT abnormal
True label: normal
<Figure size 600x600 with 64 Axes>

link

タグ: 3D-CNN CTスキャン COVID-19 医療画像解析 ディープラーニング

5月26日 01:48 投稿