畳み込みニューラルネットワークにおけるGlobal Response Normalizationの実装

GRNモジュールの構造

Global Response Normalization(GRN)はConvNeXtV2で提案された正規化手法であり、特徴マップのチャネル間依存性をモデル化する機構として機能します。SE、ECA、CBAMなどのアテンション機構と同様に、特徴量の再調整を実現します。

PyTorch実装例

import torch
import torch.nn as nn

class ChannelAttentionNorm(nn.Module):
    def __init__(self, num_channels, epsilon=1e-6, data_layout='NCHW'):
        super().__init__()
        self.epsilon = epsilon
        self.set_dimensions(data_layout)
        
        self.gamma = nn.Parameter(torch.zeros(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))
    
    def set_dimensions(self, layout):
        if layout == 'NHWC':
            self.spatial_axes = (1, 2)
            self.channel_axis = -1
            self.param_shape = (1, 1, 1, -1)
        else:
            self.spatial_axes = (2, 3)
            self.channel_axis = 1
            self.param_shape = (1, -1, 1, 1)

    def forward(self, feature_map):
        # グローバル特徴集約
        spatial_norm = feature_map.norm(p=2, dim=self.spatial_axes, keepdim=True)
        
        # 特徴量正規化
        channel_mean = spatial_norm.mean(dim=self.channel_axis, keepdim=True)
        normalized = spatial_norm / (channel_mean + self.epsilon)
        
        # 特徴量再調整
        reweighted = feature_map * normalized
        adjusted = reweighted * self.gamma.view(self.param_shape) + self.beta.view(self.param_shape)
        
        return feature_map + adjusted

# 検証用
if __name__ == "__main__":
    attention_layer = ChannelAttentionNorm(num_channels=128, data_layout='NCHW')
    test_input = torch.randn(8, 128, 56, 56)
    output = attention_layer(test_input)

機能モジュールの詳細

1. グローバル特徴集約

空間次元(高さと幅)に対してL2ノルムを計算し、空間的特徴を単一のベクトルに集約します。

spatial_norm = feature_map.norm(p=2, dim=self.spatial_axes, keepdim=True)

2. 特徴量正規化

チャネル次元で平均値を算出し、各チャネルの相対的重要度を0~1の範囲で正規化します。

channel_mean = spatial_norm.mean(dim=self.channel_axis, keepdim=True)
normalized = spatial_norm / (channel_mean + self.epsilon)

3. 特徴量再調整

正規化された重みを入力特徴量に適用し、学習可能なパラメータγ(重み)とβ(バイアス)で調整します。スキップ接続により勾配消失問題を緩和します。

reweighted = feature_map * normalized
adjusted = reweighted * self.gamma.view(self.param_shape) + self.beta.view(self.param_shape)
return feature_map + adjusted

タグ: ConvNeXtV2 GlobalResponseNormalization 畳み込みニューラルネットワーク 特徴量正規化 PyTorch

5月21日 07:45 投稿