DiT (Diffusion Transformer) のアーキテクチャと潜在拡散推論パイプライン実装解析

DiTの基本概念と潜在拡散の枠組み

DiT(Diffusion Transformer)は、従来の拡散モデルにおいてU-Netバックボーンを純粋なTransformer構造へ置換したモデルアーキテクチャである。テキスト生成や動画生成で採用されている大規模トランスフォーマーと画像生成パイプラインの構造を統一することで、スケーラビリティと最適化効率を向上させることを目的としている。潜在拡散モデル(Latent Diffusion)の設計思想を継承しており、ピクセル空間ではなく圧縮された潜在空間においてノイズ除去プロセスを執行する。

基本的な構成要素は以下の3つに分類される。

  • スケジューラ(Sampler):逆拡散過程における時間ステップの制御とノイズ除去係数の計算を担当する。
  • VAE(変分自己符号化器):画像空間と潜在空間を双方向に変換し、高次元データの計算コストを低減する。
  • トランスフォーマーコア:時刻情報とクラス条件を統合し、次のステップで除去すべきノイズ成分を予測する主ネットワーク。

公開されている標準版DiTは大規模なテキスト対画像データセットではなく、ImageNetクラス分類データを用いて学習されているため、テキストプロンプトではなく整数のクラスIDを条件入力として受け取る。

推論パイプラインとサンプリング制御

画像生成プロセスは、正規分布からサンプリングされた初期潜在テンソルの生成から始まる。解像度256×256の3チャンネル画像をターゲットとする場合、VAEエンコーダを経由すると空間次元が1/8に縮小され、4チャンネルの潜在マップ(32×32×4)へ変換される。この低次元空間で拡散シミュレーションを実行することで、メモリ消費と計算グラフの複雑さを大幅に抑制できる。

推論フェーズでは、指定されたステップ数に従って逐次的なノイズ除去が実行される。各イテレーションにおいて、現在の時間ステップと生成対象のクラスIDがネットワークに入力され、予測されたノイズベクトルが潜在状態から分離される。最終的に得られたノイズ除去済み潜在表現をVAEデコーダへ渡すことで、元々の画像解像度へ再構成される。

Classifier-Free Guidanceの実装と制御フロー

生成品質を向上させるため、Classifier-Free Guidance(CFG)が適用される。ガイドランス強度が閾値を上回る場合、条件付き(クラス指定あり)と無条件(クラス未指定)の2つの潜在状態をバッチ次元で連結し、並列フォワードパスで処理する。ネットワーク出力からノイズ予測値を取得後、両者の差ベクトルに重みを掛けることで、指示準拠度と視覚的品質のバランスを調整する。

数式的には以下の加重合成が実行される。

guided_noise = uncond_noise + cfg_weight × (cond_noise − uncond_noise)

調整済みのノイズ推定値をスケジューラのステップ関数へ渡すことで、逆拡散公式に従って次の時刻の潜在状態へ更新される。このプロセスは指定されたイテレーション数完了時まで繰り返される。

トランスフォーマーベースのネットワーク設計

条件埋め込みとadaLN-Zeroモジュール

トランスフォーマーブロックへの入力前処理は、離散時間ベクトルとクラスIDの連続表現への変換から始まる。これらは正弦波埋め込みと線形投影を経て結合され、SiLU活性化と全結合層を通じてadaLN-Zero(Adaptive Layer Norm Zero)制御信号へ変換される。adaLN-Zeroは潜在特徴量に対して動的なスケールとシフトを適用し、加えてアテンション演算とフィードフォワード演算の出力に掛算するゲート値を生成する。これにより、各拡散ステップとカテゴリ条件に最適化された表現がネットワーク全体へ注入される。

パッチ埋め込みとシーケンス化

画像特徴の処理はパッチ埋め込み層が担う。空間解像度を半分へ縮小しつつチャンネル次元を拡張する畳み込み演算を行い、2次元テンソルをシーケンス形式(バッチ、パッチ数、チャンネル)へ平坦化する。この際、シークハンスな位置エンベディングが各トークンへ追加されることで、空間的相関を保持したままTransformerへ入力される。

トランスフォーマーブロックとUnpatchify出力

メインのトランスフォーマーブロックでは、標準的なSelf-AttentionとPointwise Feedforward Networkが階層的に配置される。adaLN-Zeroから取得した適応型パラメータは、各モジュールの前後で正規化制御やゲート演算に利用され、残差接続によって勾配伝播が安定する。最終出力層では、高次元のシーケンス表現を空間構造へ復元するUnpatchify処理が実行される。線形変換とレイヤノルムを適用した後、テンソル形状の再配置およびテンソル積による軸の整合化を経て、VAEデコーダが期待する潜在チャネル構成へ投影される。

完全な推論実装コード

以下のコードは、DiTの推論プロセスをモジュール化して実装した例である。変数命名と制御フローを標準化し、CFG制御と潜在空間変換の処理を明確に分離している。

import torch
import os
import json
from diffusers import AutoencoderKL, DPMSolverMultistepScheduler
from diffusers.models import Transformer2DModel

class DiffusionTransformerEngine:
    """DiTアーキテクチャを用いた潜在拡散推論パイプライン"""
    
    def __init__(self, checkpoint_root: str, device: str = "cuda"):
        self.device = torch.device(device)
        self.dtype = torch.float32 if "mps" in self.device.type else torch.float32
        
        self.scheduler = DPMSolverMultistepScheduler.from_pretrained(
            checkpoint_root, subfolder="scheduler"
        )
        self.dit_backbone = Transformer2DModel.from_pretrained(
            checkpoint_root, subfolder="transformer"
        ).to(self.device, dtype=self.dtype)
        self.vae_decoder = AutoencoderKL.from_pretrained(
            checkpoint_root, subfolder="vae"
        ).to(self.device, dtype=self.dtype)
        
        index_path = os.path.join(checkpoint_root, "model_index.json")
        with open(index_path, "r") as f:
            self.label_mapping = json.load(f).get("id2label", {})

    def _denormalize_latents(self, z_clean: torch.Tensor) -> torch.Tensor:
        """潜在ベクトルをVAEスケールに合わせて正規化し、画像空間へ復元"""
        z_rescaled = z_clean / self.vae_decoder.config.scaling_factor
        reconstructed = self.vae_decoder.decode(z_rescaled).sample
        # [-1, 1] → [0, 1] 変換とテンソル形式の最適化
        normalized = (reconstructed / 2.0 + 0.5).clamp(min=0.0, max=1.0)
        return normalized.permute(0, 2, 3, 1).cpu().float()

    def run_inference(
        self, 
        target_categories: list[int], 
        inference_steps: int = 25, 
        cfg_strength: float = 5.0, 
        random_seed: int = 42
    ) -> list:
        """クラス条件付き画像生成を実行"""
        torch.manual_seed(random_seed)
        
        # 初期潜在ノイズの生成 (256px画像対応: 32x32x4)
        latent_shape = (len(target_categories), 4, 32, 32)
        z_state = torch.randn(latent_shape, device=self.device, dtype=self.dtype)
        
        # CFG対応のためのバッチ拡張準備
        enable_guidance = cfg_strength > 1.0
        z_working = torch.cat([z_state, z_state], dim=0) if enable_guidance else z_state
        
        # クラスIDとダミーID(無条件生成用)の連結
        cat_ids = torch.tensor(target_categories, device=self.device)
        null_ids = torch.full_like(cat_ids, fill_value=1000)
        condition_ids = torch.cat([cat_ids, null_ids], dim=0) if enable_guidance else cat_ids
        
        # 時間ステップの初期化
        self.scheduler.set_timesteps(inference_steps, device=self.device)
        time_steps = self.scheduler.timesteps
        
        for current_t in time_steps:
            if enable_guidance:
                z_working = torch.cat([z_state, z_state], dim=0)
                
            # モデル入力のスケール調整
            z_scaled = self.scheduler.scale_model_input(z_working, current_t)
            t_tensor = torch.full((z_scaled.shape[0],), current_t, device=self.device)
            
            # トランスフォーマーへのフォワードパス
            noise_prediction = self.dit_backbone(
                z_scaled, timestep=t_tensor, class_labels=condition_ids
            ).sample
            
            # ガイドランスによるノイズ推定値の調整
            if enable_guidance:
                eps_cond_part, eps_rest_part = noise_prediction[:, :4], noise_prediction[:, 4:]
                e_pos, e_neg = torch.split(eps_cond_part, eps_cond_part.shape[0] // 2, dim=0)
                e_guided = e_neg + cfg_strength * (e_pos - e_neg)
                eps_final = torch.cat([e_guided, e_guided], dim=0)
                adjusted_noise = torch.cat([eps_final, eps_rest_part], dim=1)
            else:
                adjusted_noise = noise_prediction
                
            # チャンネル次元の分離(必要に応じて)
            if self.dit_backbone.config.out_channels // 2 == 4:
                model_output, _ = torch.split(adjusted_noise, 4, dim=1)
            else:
                model_output = adjusted_noise
                
            # 逆拡散ステップの適用
            z_step_result = self.scheduler.step(model_output, current_t, z_working).prev_sample
            
            if enable_guidance:
                z_state = z_step_result[:len(z_state)]
            else:
                z_state = z_step_result
                
        # 最終潜在表現から画像への変換
        image_tensor = self._denormalize_latents(z_state)
        return image_tensor

# 使用例
if __name__ == "__main__":
    model_root = "path/to/DiT-XL-2-256"
    generator = DiffusionTransformerEngine(model_root, device="cuda")
    
    # ImageNetクラスIDを指定(例: 425=クジラ, 843=傘)
    output_images = generator.run_inference(
        target_categories=[425, 843],
        inference_steps=25,
        cfg_strength=7.5,
        random_seed=100
    )
    
    for idx, img_data in enumerate(output_images):
        import PIL.Image
        pil_img = PIL.Image.fromarray((img_data * 255).numpy().astype("uint8"))
        pil_img.save(f"generated_sample_{idx}.png")

タグ: DiT Diffusion Transformer Latent Diffusion PyTorch Classifier-Free Guidance

5月15日 02:02 投稿