GAN画像生成の基本概念
生成的対抗ネットワーク(Generative Adversarial Networks, GAN)は、無監督学習において複雑なデータ分布を生成するための手法です。Ian J. Goodfellowらが2014年に提案し、生成モデルと識別モデルの対立学習によって高品質なデータ生成を実現します。
GANの基本構造は以下の通りです:
- 生成モデル:潜在変数を入力として、訓練データに類似した偽画像を生成
- 識別モデル:入力画像が本物か偽物かを分類
学習過程では、生成モデルは偽画像をより本物に近づけ、識別モデルは偽画像を正確に識別する能力を向上させます。理想的な均衡点では、識別モデルの出力確率が0.5となり、偽画像と本物画像の分布が一致します。
数学的表現では、識別モデルの出力確率をD(x)、生成モデルをG(z)とすると、損失関数は次式で表されます:
$$\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[ \log(1-D(G(z)))]$$
データセットの準備
MNIST手書き数字データセット(70,000枚)を用いてGANを訓練します。画像サイズは28×28ピクセル、グレースケールで、訓練用60,000枚、テスト用10,000枚のデータが含まれます。
データのロードと前処理
MindSporeのMnistDatasetを使用し、データをバッチ処理します:
import mindspore.dataset as ds
noise_dim = 128 # 潜在空間の次元
batch_size = 32
train_set = ds.MnistDataset(dataset_dir='./mnist_data/train')
test_set = ds.MnistDataset(dataset_dir='./mnist_data/test')
def preprocess(dataset):
dataset = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True)
dataset = dataset.map(
operations=lambda x: (x.astype('float32') / 127.5 - 1, np.random.normal(size=noise_dim).astype('float32')),
output_columns=["image", "noise_vector"]
)
return dataset.batch(batch_size)
mnist_loader = preprocess(train_set)
生成モデルの構築
生成モデルは5層の全結合層で構成され、ReLU活性化関数とBatchNormを適用します:
class Generator(nn.Cell):
def __init__(self, input_dim, output_size):
super().__init__()
self.model = nn.SequentialCell(
nn.Dense(input_dim, 256),
nn.ReLU(),
nn.BatchNorm1d(256),
nn.Dense(256, 512),
nn.ReLU(),
nn.BatchNorm1d(512),
nn.Dense(512, 1024),
nn.ReLU(),
nn.BatchNorm1d(1024),
nn.Dense(1024, output_size),
nn.Tanh()
)
def construct(self, x):
x = self.model(x)
return ops.reshape(x, (-1, 1, 28, 28))
generator = Generator(noise_dim, 28*28)
generator.update_parameters_name('gen_model')
識別モデルの構築
識別モデルは3層の全結合層で構成され、LeakyReLUとSigmoidを適用します:
class Discriminator(nn.Cell):
def __init__(self):
super().__init__()
self.model = nn.SequentialCell(
nn.Dense(28*28, 512),
nn.LeakyReLU(0.2),
nn.Dense(512, 256),
nn.LeakyReLU(0.2),
nn.Dense(256, 1),
nn.Sigmoid()
)
def construct(self, x):
x = ops.reshape(x, (-1, 28*28))
return self.model(x)
discriminator = Discriminator()
discriminator.update_parameters_name('disc_model')
学習プロセス
識別モデルと生成モデルの学習を交互に行います:
loss_fn = nn.BCELoss()
optimizer_d = nn.Adam(discriminator.trainable_params(), learning_rate=0.0002)
optimizer_g = nn.Adam(generator.trainable_params(), learning_rate=0.0002)
def train_discriminator(real_images, noise_vectors):
fake_images = generator(noise_vectors)
real_preds = discriminator(real_images)
fake_preds = discriminator(fake_images)
real_loss = loss_fn(real_preds, ops.ones_like(real_preds))
fake_loss = loss_fn(fake_preds, ops.zeros_like(fake_preds))
total_loss = real_loss + fake_loss
return total_loss
def train_generator(noise_vectors):
fake_images = generator(noise_vectors)
fake_preds = discriminator(fake_images)
return loss_fn(fake_preds, ops.ones_like(fake_preds))
for epoch in range(10):
for real_images, _ in mnist_loader:
# 識別モデルの更新
disc_loss = train_discriminator(real_images, noise_vectors)
optimizer_d(grads_d)
# 生成モデルの更新
gen_loss = train_generator(noise_vectors)
optimizer_g(grads_g)
学習結果の可視化
訓練中の生成画像を保存し、進捗を確認します:
def generate_images(noise_vectors, epoch):
generated = generator(noise_vectors)
plt.figure(figsize=(5,5))
for i in range(25):
plt.subplot(5,5,i+1)
plt.imshow(generated[i,0].asnumpy(), cmap='gray')
plt.axis('off')
plt.savefig(f'./results/generated_epoch_{epoch}.png')