レジデュアルネットワーク(ResNet)の原理と実装

より深いネットワークを設計するにつれて、「新たに追加された層がニューラルネットワークの性能をどのように向上させるか」ということを深く理解することが重要になります。さらに重要なのは、そのようなネットワークを設計する能力であり、追加された層がネットワークをより表現力豊かなものにします。質的な飛躍を達成するためには、いくつかの数学的基礎知識が必要です。

関数クラス

まず、特定のニューラルネットワークアーキテクチャクラス\(\mathcal{F}\)を想定しましょう。これは学習率やその他のハイパーパラメータ設定を含みます。すべての\(f \in \mathcal{F}\)に対して、適切なデータセットでトレーニングすることで得られるいくつかのパラメータセット(例えば重みとバイアス)が存在します。今、\(f^*\)が私たちが実際に見つけたい関数であると仮定しましょう。もし\(f^* \in \mathcal{F}\)であれば、私たちは簡単にそれをトレーニングして得ることができますが、通常はそんなに幸運ではありません。代わりに、私たちは\(\mathcal{F}\)における最良の選択である関数\(f^*_\mathcal{F}\)を見つけようとします。例えば、\(\mathbf{X}\)特徴量と\(\mathbf{y}\)ラベルを持つデータセットが与えられた場合、以下の最適化問題を解くことによってそれを見つけようとすることができます:

[f^*_\mathcal{F} := \mathop{\mathrm{argmin}}_f L(\mathbf{X}, \mathbf{y}, f) \text{ subject to } f \in \mathcal{F}. ]では、より真の\(f^*\)に近い関数をどのように得るのでしょうか?唯一の合理的な可能性は、より強力なアーキテクチャ\(\mathcal{F}'\)を設計する必要があるということです。言い換えれば、\(f^*_{\mathcal{F}'}\)が\(f^*_{\mathcal{F}}\)よりも「より近い」ことを期待しています。しかし、もし\(\mathcal{F} \not\subseteq \mathcal{F}'\)であれば、新しい体系が「より近い」とは保証されません。実際、\(f^*_{\mathcal{F}'}\)は悪化する可能性があります:図7.6.1に示すように、ネストされていない関数クラス(non-nested function)の場合、より複雑な関数クラスが必ずしも「真の」関数\(f^*\)に近づくとは限りません(複雑度は\(\mathcal{F}_1\)から\(\mathcal{F}_6\)に向かって増加)。下図の左側では、\(\mathcal{F}_3\)が\(\mathcal{F}_1\)よりも\(f^*\)により近いですが、\(\mathcal{F}_6\)はさらに離れています。一方、下図右側のネストされた関数クラス(nested function)\(\mathcal{F}_1 \subseteq \ldots \subseteq \mathcal{F}_6\)では、この問題を回避できます。

したがって、より複雑な関数クラスが小さな関数クラスを含む場合にのみ、性能の向上を保証できます。深層ニューラルネットワークでは、新たに追加された層を恒等写像(identity function)\(f(\mathbf{x}) = \mathbf{x}\)としてトレーニングできる場合、新しいモデルと元のモデルは同様に有効になります。同時に、新しいモデルがトレーニングデータセットに適合するためにより優れた解を導き出す可能性があるため、層を追加することはトレーニング誤差を減らしやすくなります。

この問題に対して、He Kaimingらは残差ネットワーク(ResNet)を提案しました (He et al., 2016)。これは2015年のImageNet画像認識チャレンジで優勝し、その後の深層ニューラルネットワークの設計に大きな影響を与えました。残差ネットワークの核心的な考え方は、追加される各層が元の関数をその要素の1つとして含みやすくなるべきだというものです。これにより、残差ブロック(residual blocks)が生まれ、この設計は深層ニューラルネットワークの構築方法に大きな影響を与えました。これにより、ResNetは2015年のImageNet大規模視覚認識チャレンジで優勝しました。

残差ブロック

ニューラルネットワークの局所に焦点を当てましょう:図7.6.2に示すように、元の入力を\(x\)とし、学習したい理想の写像を\(f(\mathbf{x})\)(下図の活性化関数の入力として)とします。下図左側の破線枠内の部分はこの写像\(f(\mathbf{x})\)を直接近似する必要がありますが、右側の破線枠内の部分は残差写像\(f(\mathbf{x}) - \mathbf{x}\)を近似する必要があります。残差写像は現実では最適化しやすいことが多いです。この節の冒頭で述べた恒等写像を学習したい理想の写像\(f(\mathbf{x})\)として、下図右側の破線枠内の上の重み付け演算(例えばアフィン)の重みとバイアスパラメータを0に設定するだけで、\(f(\mathbf{x})\)は恒等写像になります。実際には、理想の写像\(f(\mathbf{x})\)が恒等写像に非常に近い場合、残差写像も恒等写像の微細な変動を捉えやすくなります。図7.6.2右側はResNetの基本アーキテクチャである残差ブロック(residual block)です。残差ブロックでは、入力が層間のデータ経路をより速く前方に伝播できます。

ResNetはVGGの完全な\(3\times3\)畳み込み層設計を踏襲しています。残差ブロックにはまず、同じ出力チャネル数を持つ2つの\(3\times3\)畳み込み層があります。各畳み込み層の後にバッチ正規化層とReLU活性化関数が続きます。次に、これら2つの畳み込み演算をスキップする層間データ経路を通じて、入力を最後のReLU活性化関数の前に直接加えます。この設計では、2つの畳み込み層の出力が入力の形状と同じである必要があり、それらを加算できるようにします。チャネル数を変更したい場合は、入力を必要な形状に変換した後に加算演算を行うために、追加の\(1\times1\)畳み込み層を導入する必要があります。残差ブロックの実装は以下の通りです:

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_projection=False, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_projection:
            self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.projection = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
    def forward(self, X):
        identity = X
        out = F.relu(self.bn1(self.conv1(X)))
        out = self.bn2(self.conv2(out))
        
        if self.projection:
            identity = self.projection(identity)
            
        out += identity
        return F.relu(out)

このコードは2種類のネットワークを生成します:1つはuse_projection=Falseの場合で、活性化関数を適用する前に入力を出力に追加します。もう1つはuse_projection=Trueの場合で、\(1\times1\)畳み込みで調整されたチャネルと解像度を追加します。

次に、入力と出力の形状が一致する場合を見てみましょう。

block = ResidualBlock(3, 3)
X = torch.rand(4, 3, 6, 6)
Y = block(X)
print(Y.shape)
torch.Size([4, 3, 6, 6])

出力チャネル数を増加させると同時に、高さと幅を半分にすることもできます。

block = ResidualBlock(3, 6, use_projection=True, stride=2)
print(block(X).shape)
torch.Size([4, 6, 3, 3])

ResNetモデル

ResNetの最初の2層は以前に紹介したGoogLeNetと同じです:出力チャネル数が64、ストライドが2の\(7\times7\)畳み込み層の後、ストライドが2の\(3\times3\)最大プーリング層が続きます。違いは、ResNetの各畳み込み層の後にバッチ正規化層が追加されている点です。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                  nn.BatchNorm2d(64), nn.ReLU(),
                  nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

GoogLeNetの後には4つのInceptionブロックモジュールが続きますが、ResNetは4つの残差ブロックモジュールを使用し、各モジュールでは同じ出力チャネル数を持つ複数の残差ブロックを使用します。最初のモジュールのチャネル数は入力チャネル数と同じです。以前にストライドが2の最大プーリング層が使用されているため、高さと幅を減らす必要はありません。その後の各モジュールでは、最初の残差ブロックで前のモジュールのチャネル数を2倍にし、高さと幅を半分にします。

次にこのモジュールを実装しましょう。最初のモジュールには特別な処理を行っています。

def create_resnet_block(in_channels, out_channels, num_blocks, first_block=False):
    layers = []
    for i in range(num_blocks):
        if i == 0 and not first_block:
            layers.append(ResidualBlock(in_channels, out_channels, use_projection=True, stride=2))
        else:
            layers.append(ResidualBlock(out_channels, out_channels))
    return nn.Sequential(*layers)

次にResNetにすべての残差ブロックを追加します。ここでは各モジュールに2つの残差ブロックを使用します。

b2 = create_resnet_block(64, 64, 2, first_block=True)
b3 = create_resnet_block(64, 128, 2)
b4 = create_resnet_block(128, 256, 2)
b5 = create_resnet_block(256, 512, 2)

最後に、GoogLeNetと同様に、ResNetにグローバル平均プーリング層と全結合層の出力を追加します。

net = nn.Sequential(b1, b2, b3, b4, b5,
                   nn.AdaptiveAvgPool2d((1, 1)),
                   nn.Flatten(), nn.Linear(512, 10))

各モジュールには4つの畳み込み層(恒等写像の\(1\times1\)畳み込み層を除く)があります。最初の\(7\times7\)畳み込み層と最後の全結合層を合わせて、合計18層です。したがって、このモデルは通常ResNet-18と呼ばれます。異なるチャネル数とモジュール内の残差ブロック数を構成することで、異なるResNetモデルを取得できます。例えば、より深い152層のResNet-52などです。ResNetの主なアーキテクチャはGoogLeNetに似ていますが、ResNetのアーキテクチャはよりシンプルで、変更も容易です。これらの要因がすべて、ResNetが迅速に広く使用される理由となっています。下図は完全なResNet-18を示しています。

ResNetをトレーニングする前に、ResNet内の異なるモジュールの入力形状がどのように変化するかを見てみましょう。これまでのすべてのアーキテクチャでは、解像度が低下し、チャネル数が増加し、グローバル平均プーリング層がすべての特徴を集約するまで続きます。

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__, 'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 64, 56, 56])
Sequential output shape:	 torch.Size([1, 128, 28, 28])
Sequential output shape:	 torch.Size([1, 256, 14, 14])
Sequential output shape:	 torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 512, 1, 1])
Flatten output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 10])

モデルのトレーニング

以前と同様に、Fashion-MNISTデータセットでResNetをトレーニングします。

learning_rate, epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, epochs, learning_rate, d2l.try_gpu())
loss 0.016, train acc 0.995, test acc 0.892
1737.5 examples/sec on cuda:0

まとめ

  • ネストされた関数(nested function)を学習することは、ニューラルネットワークのトレーニングにおける理想的な状況です。深層ニューラルネットワークでは、別の層を恒等写像(identity function)として学習することは比較的容易です(これは極端なケースですが)。
  • 残差写像は、同じ関数をより容易に学習できます。例えば、重み層のパラメータをゼロに近似することです。
  • 残差ブロック(residual blocks)を利用することで、効果的な深層ニューラルネットワークをトレーニングできます:入力は層間の残差接続を通じてより速く前方に伝播します。
  • 残差ネットワーク(ResNet)は、その後の深層ニューラルネットワークの設計に大きな影響を与えました。

タグ: 深層学習 ニューラルネットワーク ResNet 残差学習 コンピュータビジョン

5月23日 07:32 投稿