強化学習におけるActor-Criticフレームワークは、エージェントの行動を決定するActorとその行動の価値を評価するCriticという二つの役割を持つ。この動的バランスシステムの中心的な要素が、Actor Lossである。これはあたかも演出家のように、Actorの行動選択を徐々に最適な戦略へと導く役割を果たす。
1. Actor-Criticフレームワークの数学的基盤
Actor-Criticは、方策勾配法と価値関数法の利点を組み合わせた手法である。以下がその数学的構造である。
- 方策ネットワーク(Actor):パラメータθで表され、状態sにおける行動aの確率分布πθ(a|s)を出力する。
- 価値ネットワーク(Critic):パラメータwで表され、状態sの価値Vw(s)を推定する。
Actorが状態sで行動aを選択すると、Criticはその状態の価値V(s)を評価する。この相互作用は以下のような疑似コードで表現できる。
state = env.reset()
action_probs = actor_network(state)
action = sample(action_probs)
next_state, reward = env.step(action)
value = critic_network(state)
1.1 方策勾配法の進化
従来の方策勾配法はモンテカルロサンプリングに依存しており、高分散の問題があった。Actor-CriticではCriticをベースラインとして導入することで、この分散を低減する。
- 原始方策勾配: ∇J(θ) = 𝔼[∇logπ(a|s) * Gt]
- Critic導入後: ∇J(θ) = 𝔼[∇logπ(a|s) * (Q(s,a) - V(s))]
ここで、アドバンテージ関数A(s,a) = Q(s,a) - V(s)は、行動aが平均的な行動と比較してどれだけ優れているかを示す。TD誤差δ = r + γV(s') - V(s)を用いて近似される。
2. Actor Lossの詳細な構造
Actor Lossの核心は、Criticの評価信号を方策ネットワークの最適化に変換することにある。
2.1 基本形
最も基本的なActor Lossは以下のように定義される。
L_actor = -𝔼[logπ(a|s) * A(s,a)]
2.2 シチュエーション別の動作パターン
| フィードバックの種類 | Critic評価V(s) | Actor Lossの効果 | 方策の調整方向 |
|---|---|---|---|
| 正のフィードバック | V(s) > 0 | Loss減少 | π(a|s)を増加 |
| 負のフィードバック | V(s) < 0 | Loss増加 | π(a|s)を減少 |
| 中立フィードバック | V(s) ≈ 0 | 影響弱い | ほぼ変化なし |
2.3 勾配更新のメカニズム
パラメータ更新のプロセスは以下の通りである。
- 方策勾配の計算: ∇θL = -A(s,a) * ∇θlogπ(a|s)
- パラメータ更新(Adamオプティマイザー使用例):
advantage = critic(state) - target_value
policy_loss = -torch.log(prob_action) * advantage.detach()
policy_loss.backward()
optimizer.step()
3. モンテカルロ法とTD法の比較実験
実際の実装では、Actor Lossの計算方法にモンテカルロ法とTD法のどちらを用いるか選択できる。
3.1 モンテカルロ法
- 完全なエピソードのリターンGtを使用
- 利点: 不偏推定
- 欠点: 高分散、完全な軌跡が必要
returns = compute_returns(rewards, gamma=0.99)
loss = -torch.log(probs) * returns
3.2 TD(λ)法
- nステップTD誤差を使用
- 利点: 低分散、オンライン学習が可能
- 欠点: バイアスが生じる
next_value = critic(next_state)
td_target = reward + gamma * next_value
td_error = td_target - critic(state)
loss = -torch.log(probs) * td_error
性能比較表
| 指標 | モンテカルロ | TD(0) | TD(λ) |
|---|---|---|---|
| 分散 | 高 | 低 | 中 |
| バイアス | 無 | 有 | 有 |
| 収束速度 | 遅い | 速い | 比較的速い |
| データ効率 | 低 | 高 | 比較的高 |
4. 実践:PyTorch実装における重要なポイント
Actor-Criticの実装時には、いくつかの注意点がある。
4.1 確率分布の安定性
Softmax出力を使用する方策では、勾配消失が発生する可能性がある。対策として、数値的に安定したlog確率計算が推奨される。
log_probs = F.log_softmax(policy_output, dim=-1)
selected_log_probs = advantage * log_probs.gather(1, actions)
policy_loss = -selected_log_probs.mean()
4.2 アドバンテージの標準化
アドバンテージ関数を標準化することで、学習の安定性が向上する。
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
4.3 二重Criticネットワーク
SACアルゴリズムのアイデアを借用し、過大評価を防ぐために二つのCriticを利用する。
value1 = critic1(state)
value2 = critic2(state)
td_target = reward + gamma * torch.min(critic1(next_state), critic2(next_state))
loss1 = F.mse_loss(value1, td_target.detach())
loss2 = F.mse_loss(value2, td_target.detach())
5. 高度な最適化戦略
5.1 エントロピー正則化
Lossに方策のエントロピー項を追加し、探索を促進する。
entropy = -torch.sum(probs * torch.log(probs), dim=-1)
policy_loss = - (log_probs * advantage.detach()).mean() - 0.01 * entropy.mean()
5.2 信頼領域最適化
KLダイバージェンス制約を用いて、方策の更新幅を抑制する。
L(θ) = 𝔼[π_new/π_old * A] - β*KL(π_old||π_new)
5.3 適応的学習率
勾配のノルムに基づいて学習率を動的に調整する。
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
lr = base_lr * (1.0 / (1.0 + grad_norm))