生成モデルの基礎:GANの理論と実装入門

生成モデル概要

GPTやQwenなどの一般的な生成モデルは主にテキスト生成に特化していますが、画像生成技術では主に以下の三つのアプローチが主流です:1. GAN、2. VAE、3. 拡散モデル。本記事では生成モデルの基礎としてGANについて詳細に解説します。

GAN(Generative Adversarial Networks)の理論

GANの核心概念は、生成ネットワークGがデータ分布を学習し、識別ネットワークDがその生成物が訓練データかモデル出力かを判定することです。

数式的表現では、潜在変数zからノイズ分布p_z(z)を生成し、生成器G(z;θ_g)を通じてデータ空間にマッピングします。識別器D(x;θ_d)は、入力が実データか生成データかを識別します。

損失関数は以下のようになります:

V(G,D) = E[log D(x)] + E[log(1 - D(G(z)))]

このゲームにおいて、識別器は生成物をできるだけ識別しようとし、生成器は識別されないように本物に近いデータを生成しようとします。

数学的解析

最適な識別器D*を得るために、以下の式を最大化します:

D*(x) = p_data(x) / (p_data(x) + p_gen(x))

このとき、最適な識別器における値は:

max_D V(G,D) = -2log2 + 2JSD(p_data||p_gen)

生成器の最適化目標は、Jensen-Shannonダイバージェンスを最小化することであり、理想的にはp_data = p_genが成立する状態を目指します。

実装コード

MNISTデータセットを使用した基本的なGAN実装:

class Generator(nn.Module):
    def __init__(self, latent_dim, output_dim):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            nn.Tanh()
        )
    
    def forward(self, noise):
        return self.network(noise)

class Discriminator(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    
    def forward(self, data):
        return self.classifier(data)

def train_step(real_batch, generator, discriminator, g_optim, d_optim, criterion):
    batch_size = real_batch.size(0)
    true_labels = torch.ones(batch_size, 1, device=device)
    false_labels = torch.zeros(batch_size, 1, device=device)
    
    # 識別器の訓練
    d_optim.zero_grad()
    
    # 実データの処理
    real_output = discriminator(real_batch)
    loss_real = criterion(real_output, true_labels)
    
    # 生成データの処理
    latent_noise = torch.randn(batch_size, latent_size, device=device)
    generated_samples = generator(latent_noise)
    fake_output = discriminator(generated_samples.detach())
    loss_fake = criterion(fake_output, false_labels)
    
    d_loss = loss_real + loss_fake
    d_loss.backward()
    d_optim.step()
    
    # 生成器の訓練
    g_optim.zero_grad()
    fake_output = discriminator(generated_samples)
    g_loss = criterion(fake_output, true_labels)
    g_loss.backward()
    g_optim.step()

問題点と改善策

訓練の不安定性

JSダイバージェンスが飽和状態に達すると、勾配消失が発生し学習が進まなくなることがあります。WGANではWasserstein距離を使用してこの問題を解決しています。

モード崩壊

生成器が特定の出力パターンのみを生成し続ける現象です。これにより多様性が失われます。

WGANの実装変更点

WGANでは以下の変更が必要です:

  • 識別器(クリティック)の最終層にシグモイド非使用
  • Adamなどのモメンタムベース最適化の回避
  • 勾配クリッピングの実施
# WGANの損失計算
def wgan_discriminator_loss(critic_real, critic_fake):
    return -(torch.mean(critic_real) - torch.mean(critic_fake))

def wgan_generator_loss(critic_fake):
    return -torch.mean(critic_fake)

# 勾配制限
for param in critic.parameters():
    param.data.clamp_(-clip_bound, clip_bound)

タグ: GAN WGAN 生成モデル 深層学習 PyTorch

6月6日 22:01 投稿