PyTorch で構築する ResNet50 実装のステップバイステップ解説

Residual Network (ResNet) の基礎構造

ResNet(Residual Network)は、非常に深いニューラルネットにおいて勾配消失問題を解決するために考案されたアーキテクチャです。この手法の核となるのは「スキップ接続(Skip Connection)」または「残差パス」と呼ばれる仕組みで、入力を直接出力層へ結合することで、F(x) + x という形での学習を実現します。これにより、層が深くなっても情報が効率的に伝播され、最適化が容易になります。

ResNet の主要な構成要素として、主に以下の 2 つの基本ユニットがあります:

  • Standard Block: 通常の 2 つの畳み込みレイヤーを持つ構成。比較的浅いネットワーク(例:ResNet18, ResNet34)で使用されます。
  • Bottleneck Block: 1×1、3×3、1×1 の 3 つの畳み込みレイヤーで構成される縮約と復元のブロック。計算効率を向上させ、より深いネットワーク(例:ResNet50, ResNet101)に適しています。

PyTorch による実装コード

以下に、ResNet50 のベースとなるネットワークを PyTorch でゼロから定義する手順を示します。既存のライブラリを使用せず、構造を理解するための実装例となっています。

1. ライブラリのインポートと定数定義

まず、必要なモジュールを読み込みます。

import torch.nn as nn
import math

# 重みファイルへのパス設定(例示)
WEIGHT_PATHS = {
    'resnet50': 'weights/resnet50_checkpoint.pth'
}

2. ユーティリティ関数の定義

一般的な 3x3 コンボリューションレイヤーを作成するヘルパー関数です。

def build_convolution(in_channels, out_channels, stride=1):
    return nn.Conv2d(
        in_channels, 
        out_channels, 
        kernel_size=3, 
        stride=stride, 
        padding=1, 
        bias=False
    )

3. Standard Residual Block の実装

基本的な残差ブロッククラスです。expansion_factor は 1 です。

class StandardResBlock(nn.Module):
    expansion_factor = 1

    def __init__(self, input_channels, output_channels, stride=1, downsample=None):
        super().__init__()
        
        self.conv_a = build_convolution(input_channels, output_channels, stride)
        self.bn_a = nn.BatchNorm2d(output_channels)
        self.activation = nn.ReLU(inplace=True)
        
        self.conv_b = build_convolution(output_channels, output_channels, stride)
        self.bn_b = nn.BatchNorm2d(output_channels)
        
        self.downsample = downsample
        self.stride = stride

    def forward(self, input_tensor):
        residual_path = input_tensor

        out = self.conv_a(input_tensor)
        out = self.bn_a(out)
        out = self.activation(out)

        out = self.conv_b(out)
        out = self.bn_b(out)

        # 入出力チャネルまたはストライドが異なる場合のみ、残差パスに変換を追加
        if self.downsample is not None:
            residual_path = self.downsample(residual_path)

        out += residual_path
        out = self.activation(out)
        return out

4. Bottleneck Residual Block の実装

ResNet50 以降で使用される、より圧縮されたブロック構造です。expansion_factor は 4 になります。

class BottleneckResBlock(nn.Module):
    expansion_factor = 4

    def __init__(self, input_channels, output_channels, stride=1, downsample=None):
        super().__init__()
        
        # 1x1 畳み込みでチャネル数を削減・変換
        self.conv_1 = nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=stride, bias=False)
        self.bn_1 = nn.BatchNorm2d(output_channels)
        
        # 3x3 畳み込み
        self.conv_2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn_2 = nn.BatchNorm2d(output_channels)
        
        # 1x1 畳み込みで元に戻す(チャネル数を 4 倍拡張)
        self.conv_3 = nn.Conv2d(output_channels, output_channels * 4, kernel_size=1, bias=False)
        self.bn_3 = nn.BatchNorm2d(output_channels * 4)
        
        self.activation = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, input_tensor):
        residual_path = input_tensor

        identity = self.conv_1(input_tensor)
        identity = self.bn_1(identity)
        identity = self.activation(identity)

        middle = self.conv_2(identity)
        middle = self.bn_2(middle)
        middle = self.activation(middle)

        out = self.conv_3(middle)
        out = self.bn_3(out)

        if self.downsample is not None:
            residual_path = self.downsample(residual_path)

        out += residual_path
        out = self.activation(out)
        return out

5. メイン ResNet クラスの実装

すべてのレイヤをまとめる親クラスです。_build_stage メソッドが各セクション(layer1〜4)を動的に生成します。

class DeepResNet(nn.Module):
    def __init__(self, block_type, stage_blocks_config, num_classes=1000):
        super().__init__()
        
        self.initial_features = 64
        
        # 初期畳み込み(Conv1)
        self.conv_init = nn.Conv2d(3, self.initial_features, kernel_size=7, stride=2, padding=3, bias=False)
        self.norm_init = nn.BatchNorm2d(self.initial_features)
        self.relu_init = nn.ReLU(inplace=True)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
        
        # ステージごとのブロック構築
        self.stage1 = self._build_stage(block_type, 64, stage_blocks_config[0])
        self.stage2 = self._build_stage(block_type, 128, stage_blocks_config[1], stride=2)
        self.stage3 = self._build_stage(block_type, 256, stage_blocks_config[2], stride=2)
        self.stage4 = self._build_stage(block_type, 512, stage_blocks_config[3], stride=2)

        # プールリングと完全結合層
        self.avg_pool = nn.AvgPool2d(7)
        self.fc_classifier = nn.Linear(512 * block_type.expansion_factor, num_classes)

        # 重みの初期化ロジック
        self._initialize_weights()

    def _build_stage(self, block_cls, channels_per_block, num_blocks, stride=1):
        down_proj = None
        
        # ストライドやチャンネル数が一致しない場合はダウンサンプリングが必要
        if stride != 1 or self.initial_features != channels_per_block * block_cls.expansion_factor:
            down_proj = nn.Sequential(
                nn.Conv2d(self.initial_features, channels_per_block * block_cls.expansion_factor, 
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels_per_block * block_cls.expansion_factor)
            )

        layers_list = []
        first_block = block_cls(self.initial_features, channels_per_block, stride, down_proj)
        layers_list.append(first_block)
        
        # 次の入力チャネル数を更新
        self.initial_features = channels_per_block * block_cls.expansion_factor
        
        # 残りのブロック追加
        for _ in range(1, num_blocks):
            layers_list.append(block_cls(self.initial_features, channels_per_block))

        return nn.Sequential(*layers_list)

    def _initialize_weights(self):
        for layer_module in self.modules():
            if isinstance(layer_module, nn.Conv2d):
                n = layer_module.kernel_size[0] * layer_module.kernel_size[1] * layer_module.out_channels
                nn.init.normal_(layer_module.weight, mean=0.0, std=math.sqrt(2. / n))
                if layer_module.bias is not None:
                    nn.init.constant_(layer_module.bias, 0)
            elif isinstance(layer_module, nn.BatchNorm2d):
                nn.init.constant_(layer_module.weight, 1)
                nn.init.constant_(layer_module.bias, 0)

    def forward(self, input_data):
        x = self.conv_init(input_data)
        x = self.norm_init(x)
        x = self.relu_init(x)
        x = self.max_pool(x)

        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)

        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc_classifier(x)
        return x

6. モデルの検証と確認

ResNet50 のパラメータ配置を確認するには、以下の様にインスタンス化し構造を確認します。

# ResNet50 に相当する構成 [3, 4, 6, 3] ボリューム設定
# ResNet18 などの基本型は StandardResBlock を使用
model_instance = DeepResNet(BottleneckResBlock, [3, 4, 6, 3], 1000)
print(model_instance)

上記を実行すると、畳み込み層の数やストライド、全結合層の入力次元などが一覧表示されます。また、StandardResBlockBottleneckResBlock を切り替えることで、同じインターフェースから様々な深さの ResNet バリアントを瞬時に構築可能です。

タグ: PyTorch deep-learning ResNet computer-vision neural-network

6月4日 16:32 投稿