大規模ニューラルネットワーク訓練における完全分片データ並行(FSDP)とZeRO最適化の技術解説

分散並行学習とメモリ制約への対応

従来のデータ並行(DP)や分散データ並行(DDP)では、ミニバッチ入力の分割と勾配同期が主たる手法でした。しかし、パラメータ数が数千億規模に達した現代の大規模モデル訓練において、単一アクセラレータのメモリ容量は明確なボトルネックとなっています。完全分片データ並行(Fully Sharded Data Parallel: FSDP)は、モデルパラメータ、最適化状態、勾配データをプロセス間で動的に分割し、必要なタイミングでのみ集合通信を用いて再構築するアーキテクチャです。本記事では、混精度計算やメモリ推定の基礎からZeRO最適化の段階的進化、そしてPyTorch実装と通信効率の数理的分析までを体系的に解説します。

基礎技術:精度形式と混精度計算

高精度な浮動小数点演算は計算コストとメモリ帯域を大きく消費します。深度学习では、精度の低下がモデル性能に与える影響を最小限に抑えつつ、計算効率を最大化するための数値形式が標準化されています。

  • FP32:IEEE 754単精度浮動小数点。長年標準とされてきましたが、メモリフットプリントが大きい。
  • FP16:半精度浮動小数点。メモリ効率が2倍向上し、TPU/GPUで広く採用。ただし表現範囲が狭く、下溢れ(underflow)のリスクがある。
  • BF16:Brain Floating Point。指数部をFP32と同様に8bit確保し、FP16の表現範囲不足を補完。大規模モデル訓練で事実上の標準となりつつある。
  • TF32:Tensor Core向け最適化形式。FP32のmantissaを10bitに丸めて計算し、結果をFP32で復元。コード変更なしで高速化を図れる。

混精度訓練では、フォワード/バックワードパスで低精度(例:BF16)を用いて演算とメモリ使用量を削減し、重み更新時には高精度(FP32)の主コピーに適用します。PyTorch FSDPでは、MixedPrecisionクラスを用いてこの挙動を宣言的に設定できます。

from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

bf16_precision = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16
)

distributed_model = FSDP(
    target_network,
    mixed_precision=bf16_precision,
    auto_wrap_policy=custom_wrap_fn
)

損失スケーリング(Loss Scaling)

低精度演算では微小な勾配値がゼロと判定される「下溢れ」が発生します。これを回避するために、損失値に定数倍のスケーリングを適用し、チェーンルールを通じて勾配値を有効な表現範囲に押し上げます。PyTorchのGradScalerは動的スケーリングを実装しており、勾配にNaN/Infが検出された場合にスケーリング係数を自動的に低下させ、安定したイテレーション後には係数を漸進的に増加させます。

from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()
with autocast():
    preds = network(inputs)
    loss = criterion(preds, labels)

scaler.scale(loss).backward()
scaler.step(solver)
scaler.update()

メモリ消費の定量評価

訓練時のメモリ要件は「モデル状態」と「残差状態」に分類されます。

  • モデル状態:パラメータ(W)、勾配(G)、最適化状態(Adamのmomentum/varianceなど)。必須の永続データ。
  • 残差状態:活性化値(Activation)、通信バッファ、メモリ断片。計算途中の一時データ。

混合精度訓練(FP32パラメータ/最適化状態、FP16勾配/計算)におけるモデル状態のメモリ消費は、パラメータ数$\Phi$に対して以下のようになります。

構成要素消費メモリ(バイト)
FP32 パラメータ$4\Phi$
FP32 Adam Momentum$4\Phi$
FP32 Adam Variance$4\Phi$
FP16 勾配$2\Phi$
FP16 計算用パラメータコピー$2\Phi$
合計$16\Phi$

Transformerアーキテクチャを想定すると、隠れ層次元$hd$、レイヤー数$nl$、バッチサイズ$bsz$、シーケンス長$seq$を用いて、モデル状態は約$12 \times nl \times hd^2$パラメータに近似できます。残差状態(活性化メモリ)はチェックポイント間隔$ci$を考慮し、$2 \times bsz \times seq \times hd \times nl / ci$と見積もられます。これらがGPUメモリ容量を超過する際に、分割並行の必要性が生まれます。

ZeRO最適化アーキテクチャ

ZeRO(Zero Redundancy Optimizer)は、メモリ使用量を直線的に削減する3段階のモデル状態分割戦略と、残差状態最適化から成ります。

ZeRO-DP:モデル状態の分割段階

並行程$N_d$に対する分割戦略は以下の3段階で構成されます。

  1. ZeRO-1(最適化状態分割):AdamのMomentumとVarianceをプロセスで分割。各プロセスは$1/N_d$の状態のみ保持。通信は勾配同期時のみ発生し、メモリは約1/4に削減。
  2. ZeRO-2(勾配分割):勾配をReduce-Scatterで分割。各プロセスは担当パラメータの勾配のみ計算・保持。メモリは約1/8に削減。All-Gatherで重みを復元後に更新実行。
  3. ZeRO-3(完全分割 / FSDP):パラメータ、勾配、最適化状態を全て分割。フォワード/バックワード計算直前にAll-Gatherで重みを取得し、計算後には即時破棄。メモリ使用量は$1/N_d$に比例して減少。通信量は標準DPの約1.5倍に増加するトレードオフを持つ。

ZeRO-R:残差状態の最適化

ZeRO-Rは活性化メモリと断片化管理に焦点を当てます。

  • 分割活性化チェックポイント:モデル並行度に応じてチェックポイントを分割。必要時にAll-Gatherで再構成し、メモリを$1/N_m$に削減。極端なメモリ逼迫時はCPUメモリへオフロード($P_{a+cpu}$)し、ホスト-デバイス帯域を犠牲にしてGPUメモリを解放します。
  • 定長バッファ管理:集合通信用のテンソルパディングを固定長化し、動的メモリ確保による断片化を抑制。
  • メモリデフラグメンテーション:テンソルライフサイクルを事前計画し、割り当て/解放パターンを最適化。

ZeRO-Infinity:異種メモリ統合

NPUメモリ限界を超え、CPU RAMおよびNVMeストレージをメモリ階層として統合する拡張版です。重みと勾配を低速メモリへシームレスにオフロードし、計算直前に高速メモリへストリーミング転送します。メモリリソースのオーケストレーションにより、単一ノードの物理メモリ限界を超えた超大規模モデルの訓練が可能になります。

PyTorch FSDP実装パターン

以下は、PyTorchのtorch.distributed.fsdpを用いた実装例です。コード構造は見通しよく変更し、変数名と初期化フローを最適化しています。

import os
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import T5Block
from torch.optim.lr_scheduler import StepLR
from packaging.version import Version

def setup_distributed_training(config):
    # 分散環境情報の取得
    proc_idx = int(os.environ.get("LOCAL_RANK", "0"))
    num_procs = int(os.environ.get("WORLD_SIZE", "1"))
    
    torch.cuda.set_device(proc_idx)
    
    # Transformerブロック単位で自動分割するポリシーを定義
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={T5Block}
    )
    
    # ハードウェアサポートに基づく混合精度判定
    cuda_ver = Version(torch.version.cuda or "0.0")
    hardware_supports_bf16 = (
        cuda_ver >= Version("11.0") and
        torch.cuda.is_bf16_supported()
    )
    
    precision_cfg = MixedPrecision(
        param_dtype=torch.bfloat16 if hardware_supports_bf16 else torch.float32,
        reduce_dtype=torch.bfloat16 if hardware_supports_bf16 else torch.float32,
        buffer_dtype=torch.bfloat16 if hardware_supports_bf16 else torch.float32
    ) if hardware_supports_bf16 else None
    
    # モデルの読み込み(CPUメモリ上)とFSDPラッピング
    base_network = load_t5_base_model()
    sharded_network = FSDP(
        base_network,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=precision_cfg,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        device_id=torch.cuda.current_device()
    )
    
    solver = torch.optim.AdamW(sharded_network.parameters(), lr=config.lr)
    decay_scheduler = StepLR(solver, step_size=1, gamma=config.gamma)
    
    for epoch in range(config.epochs):
        # train_loop_step(config, sharded_network, solver)
        decay_scheduler.step()

通信コストと計算効率の数理的評価

メモリ削減と通信増加のトレードオフを定量化するため、算術強度(Arithmetic Intensity: AIT)と帯域幅利用率を分析します。AITは総演算量と総データ移動量の比であり、高いAITは通信 bottleneck の影響を相殺します。

Transformerの単一イテレーションにおける演算量とデータ移動量を推定すると、以下のように表現できます。

  • 演算量:$4 \times bsz \times seq \times 12 \times nl \times hd^2$(フォワード+バックワード×2近似)
  • 重み・勾配のデータ移動:$4 \times parameters$(フォワード取得、バックワード取得、勾配書き込み、状態更新)
  • 最適化状態のデータ移動:$2 \times optimizer\_states \approx 32 \times parameters$
  • 活性化チェックポイントの移動:$4 \times nl/ci \times hd \times seq \times bsz$

各コンポーネントの算術強度($ait$)は演算量÷データ移動量で求められます。重みと勾配の$ait$は$seq \times bsz$に比例し、最適化状態は$seq \times bsz / 4$となります。実際のGPU/NPUにおいて、重み・勾配通信は70 GB/s以上の帯域で50%を超える効率が達成可能です。一方、最適化状態はフォワード/バックワード終了時に集中して発生するため、計算オーバーラップが困難で、実効帯域1.5 TB/s近い要件が生まれます。

活性化チェックポイントを適用した場合、$ait$は$24 \times hd \times ci$となり、隠れ次元$hd$が8Kを超えると1 GB/s程度の帯域で50%以上の効率が維持されます。ZeRO-FSDPはこの特性を設計に組み込み、バッチサイズ増加による通信相対コストの低下と、チェックポイント間隔の最適化によって、超大規模モデル訓練のスループットとメモリ効率的なバランスを実現しています。

タグ: PyTorch FSDP ZeRO 分散学習 混精度計算

6月9日 18:52 投稿