DenseNetにSEモジュールを統合した猴痘ウイルス画像分類モデルの実装

データ準備と前処理

必要なライブラリをインポートし、画像データをロードする。

import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

データディレクトリからクラス名を取得し、標準的なImageNet正規化を適用したトランスフォームを定義する。

data_root = Path("./data/4-data/")
class_names = [p.name for p in data_root.iterdir() if p.is_dir()]

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = ImageFolder(data_root, transform=transform)

データセットを8:2で訓練・テストに分割し、DataLoaderを構築する。

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = torch.utils.data.random_split(dataset, [train_size, test_size])

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=1)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=1)

モデル構築:DenseNet-121 + SEモジュール

まず、Squeeze-and-Excitation(SE)モジュールを実装する。これはチャネルごとの重要度を学習し、特徴マップに重みを付与する機構である。

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        self.excite = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        y = self.squeeze(x).view(b, c)
        y = self.excite(y).view(b, c, 1, 1)
        return x * y

次に、DenseNetの基本構成要素であるConvBlockとDenseBlockを定義する。

class Bottleneck(nn.Module):
    def __init__(self, in_ch, growth_rate):
        super().__init__()
        inter_ch = 4 * growth_rate
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.conv1 = nn.Conv2d(in_ch, inter_ch, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(inter_ch)
        self.conv2 = nn.Conv2d(inter_ch, growth_rate, 3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(torch.relu(self.bn1(x)))
        out = self.conv2(torch.relu(self.bn2(out)))
        return torch.cat([x, out], dim=1)

class DenseLayer(nn.Module):
    def __init__(self, num_bottlenecks, in_ch, growth_rate):
        super().__init__()
        layers = []
        for i in range(num_bottlenecks):
            layers.append(Bottleneck(in_ch + i * growth_rate, growth_rate))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

Transition層は特徴マップの解像度を下げ、チャネル数を調整する。

class Transition(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_ch)
        self.conv = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.pool = nn.AvgPool2d(2)

    def forward(self, x):
        x = self.conv(torch.relu(self.bn(x)))
        return self.pool(x)

最終的なネットワークでは、最後のDenseBlockの出力にSEモジュールを挿入する。

class ModifiedDenseNet(nn.Module):
    def __init__(self, num_classes=4, growth_rate=32):
        super().__init__()
        self.growth_rate = growth_rate

        # 初期層
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1)
        )

        # Dense Blocks と Transition Layers
        self.dense1 = DenseLayer(6, 64, growth_rate)
        self.trans1 = Transition(64 + 6 * growth_rate, 128)

        self.dense2 = DenseLayer(12, 128, growth_rate)
        self.trans2 = Transition(128 + 12 * growth_rate, 256)

        self.dense3 = DenseLayer(24, 256, growth_rate)
        self.trans3 = Transition(256 + 24 * growth_rate, 512)

        self.dense4 = DenseLayer(16, 512, growth_rate)

        # 最終チャネル数
        final_ch = 512 + 16 * growth_rate
        self.se = SEBlock(final_ch, reduction=16)

        # 分類ヘッド
        self.final_bn = nn.BatchNorm2d(final_ch)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(final_ch, num_classes)

    def forward(self, x):
        x = self.stem(x)
        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.trans3(self.dense3(x))
        x = self.dense4(x)
        x = self.se(x)
        x = torch.relu(self.final_bn(x))
        x = self.avgpool(x).flatten(1)
        return self.classifier(x)

訓練と評価

デバイス設定、損失関数、最適化手法を定義し、訓練ループを実装する。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ModifiedDenseNet(num_classes=len(class_names)).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
epochs = 20

訓練関数とテスト関数は以下の通り。

def train_one_epoch(loader, model, loss_fn, opt):
    model.train()
    total_loss, correct = 0, 0
    n_samples = len(loader.dataset)
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()
        correct += (pred.argmax(1) == y).sum().item()
    return correct / n_samples, total_loss / len(loader)

def evaluate(loader, model, loss_fn):
    model.eval()
    total_loss, correct = 0, 0
    n_samples = len(loader.dataset)
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            total_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).sum().item()
    return correct / n_samples, total_loss / len(loader)

単一画像の推論関数も用意する。

def infer_image(img_path, model, transform, classes):
    img = Image.open(img_path).convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(tensor)
    pred_idx = logits.argmax(1).item()
    print(f"予測クラス: {classes[pred_idx]}")

実行と可視化

訓練ループを実行し、精度と損失をプロットする。

train_acc_hist, test_acc_hist = [], []
train_loss_hist, test_loss_hist = [], []

for epoch in range(epochs):
    train_acc, train_loss = train_one_epoch(train_loader, model, criterion, optimizer)
    test_acc, test_loss = evaluate(test_loader, model, criterion)

    train_acc_hist.append(train_acc)
    test_acc_hist.append(test_acc)
    train_loss_hist.append(train_loss)
    test_loss_hist.append(test_loss)

    print(f"Epoch {epoch+1:2d} | Train Acc: {train_acc*100:.1f}% | Test Acc: {test_acc*100:.1f}%")

# 精度・損失のプロット
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_acc_hist, label="Train Accuracy")
plt.plot(test_acc_hist, label="Test Accuracy")
plt.legend(); plt.title("Accuracy")

plt.subplot(1, 2, 2)
plt.plot(train_loss_hist, label="Train Loss")
plt.plot(test_loss_hist, label="Test Loss")
plt.legend(); plt.title("Loss")
plt.show()

# 推論例
infer_image("./data/4-data/Monkeypox/M01_01_00.jpg", model, transform, class_names)

SEモジュールの統合戦略と考察

DenseNetの各層は密接に接続され、特徴が再利用されるため、不要なチャネル情報が蓄積される可能性がある。SEモジュールは、チャネルごとの重要度を動的に調整することで、この問題を緩和する。

本実装では、SEモジュールを最後のDenseBlock直後に配置した。これにより、高レベルのセマンティック特徴に対してチャネル重み付けを行い、分類性能を向上させる狙いがある。

この配置の利点は、追加パラメータが少なく(約6万パラメータ)、計算オーバーヘッドが小さいことである。一方で、中間層の特徴にはSEが作用しないため、局所的な特徴選択能力は制限される。より高い性能を求める場合は、各Transition層後にSEを挿入するなどの拡張も検討できる。

タグ: DenseNet SE-Net PyTorch 画像分類 注意力機構

5月19日 06:27 投稿