「動手学強化学習」に基づく知識ポイント(5):第18章 オフライン強化学習(gymバージョン >= 0.26)

概要

本シリーズは「動手学強化学習」の内容に基づき、難点を詳細に分析します!具体的な内容については「動手学強化学習」をお読みください。

対応する章:動手学強化学習——オフライン強化学習

SACアルゴリズム部分

以下にデータセットを生成するコードを示します。SAC部分は14.5節のコードを直接使用するため、詳細な説明は省略します。——18.4 CQLコード実践

import numpy as np
import gym
from tqdm import tqdm
import random
import rl_utils
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import matplotlib.pyplot as plt


# 伝統的なSACアルゴリズム
class PolicyNetworkContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetworkContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x))
        dist = Normal(mu, std)
        normal_sample = dist.rsample()  # rsample()は再パラメータ化サンプリング
        log_prob = dist.log_prob(normal_sample)
        action = torch.tanh(normal_sample)
        # tanh_normal分布の対数確率密度を計算
        log_prob = log_prob - torch.log(1 - torch.tanh(action).pow(2) + 1e-7)
        action = action * self.action_bound
        return action, log_prob


class QValueNetworkContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(QValueNetworkContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        x = F.relu(self.fc2(x))
        return self.fc_out(x)


class SACContinuous:
    ''' 連続動作を扱うSACアルゴリズム '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device):
        self.actor = PolicyNetworkContinuous(state_dim, hidden_dim, action_dim,
                                         action_bound).to(device)  # ポリシーネットワーク
        self.critic_1 = QValueNetworkContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 最初のQネットワーク
        self.critic_2 = QValueNetworkContinuous(state_dim, hidden_dim,
                                            action_dim).to(device)  # 2番目のQネットワーク
        self.target_critic_1 = QValueNetworkContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 最初の目標Qネットワーク
        self.target_critic_2 = QValueNetworkContinuous(state_dim,
                                                   hidden_dim, action_dim).to(
                                                       device)  # 2番目の目標Qネットワーク
        # 目標Qネットワークの初期パラメータをQネットワークと同じにする
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(),
                                                   lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(),
                                                   lr=critic_lr)
        # alphaのlog値を使用することで、学習結果をより安定させる
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        self.log_alpha.requires_grad = True  # alphaの勾配を計算
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha],
                                                    lr=alpha_lr)
        self.target_entropy = target_entropy  # 目標エントロピーの大きさ
        self.gamma = gamma
        self.tau = tau
        self.device = device

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        action = self.actor(state)[0]
        return [action.item()]

    def calc_target(self, rewards, next_states, dones):  # 目標Q値を計算
        next_actions, log_prob = self.actor(next_states)
        entropy = -log_prob
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        next_value = torch.min(q1_value,
                               q2_value) + self.log_alpha.exp() * entropy
        td_target = rewards + self.gamma * next_value * (1 - dones)
        return td_target

    def soft_update(self, net, target_net):
        for param_target, param in zip(target_net.parameters(),
                                       net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) +
                                    param.data * self.tau)

    def update(self, transition_dict):
        states = torch.tensor(transition_dict['states'],
                              dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],
                                   dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(self.device)
        rewards = (rewards + 8.0) / 8.0  # 倒立振子環境の報酬をリシェイプ

        # 2つのQネットワークを更新
        td_target = self.calc_target(rewards, next_states, dones)
        critic_1_loss = torch.mean(
            F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(
            F.mse_loss(self.critic_2(states, actions), td_target.detach()))
        self.critic_1_optimizer.zero_grad()
        critic_1_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        critic_2_loss.backward()
        self.critic_2_optimizer.step()

        # ポリシーネットワークを更新
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy -
                                torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # alpha値を更新
        alpha_loss = torch.mean(
            (entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


env_name = 'Pendulum-v1'
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
action_bound = env.action_space.high[0]  # 動作の最大値
random.seed(0)
np.random.seed(0)
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
torch.manual_seed(0)

actor_lr = 3e-4
critic_lr = 3e-3
alpha_lr = 3e-4
num_episodes = 100
hidden_dim = 128
gamma = 0.99
tau = 0.005  # ソフト更新パラメータ
buffer_size = 100000
minimal_size = 1000
batch_size = 64
target_entropy = -env.action_space.shape[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
    "cpu")

replay_buffer = rl_utils.ReplayBuffer(buffer_size)
agent = SACContinuous(state_dim, hidden_dim, action_dim, action_bound,
                      actor_lr, critic_lr, alpha_lr, target_entropy, tau,
                      gamma, device)

return_list = rl_utils.train_off_policy_agent(env, agent, num_episodes,
                                              replay_buffer, minimal_size,
                                              batch_size)

episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('エピソード')
plt.ylabel('リターン')
plt.title('{}上でのSAC'.format(env_name))
plt.show()

CQLアルゴリズム

CQLの概要と主要メソッドの意義

CQL(Conservative Q-Learning) は、SACアルゴリズムに基づいて保守的正則化項を追加し、Q関数の過剰推定を減らし、オフラインRLの性能を改善するものです。

  • コンストラクタ

  • 意義: SACのactor、critic(2つと対応する目標ネットワーク)、温度パラメータと最適化器を初期化し、CQL特有のハイパーパラメータ(正則化係数betaとランダムサンプル数num_random)を渡します。

  • 入力: 状態、行動の次元、各学習率、目標エントロピー、tau、gamma、デバイス、beta、num_random。

  • 出力: 初期化されたCQLインスタンス、更新の準備完了。

  • take_action

  • 意義: 状態が与えられた場合、actorネットワークで行動をサンプリングし、環境との相互作用に使用します。

  • 入力: 単一の状態(例: [0.1, 0.2, -0.1])。

  • 出力: 対応する行動(例: [0.8])。

  • soft_update

  • 意義: 目標Qネットワークのパラメータを滑らかに更新し、学習の安定性を確保します。

  • 入力: 現在のQネットワークと対応する目標ネットワーク。

  • 出力: 目標ネットワークのパラメータが旧値と現在値の線形結合に更新される。

  • update

  • 意義: 一連の実際の環境遷移データを使用してactor、criticネットワークと温度パラメータ(α)を更新し、同時にSACにCQL正則化項を追加します。 L CQL = L critic + β ( log ⁡ ∑ exp ⁡ ( Q ( s , a ′ ) − log ⁡ π ref ( a ′ ) ) − E ( s , a ) ∼ D [ Q ( s , a ) ] ここでは、ランダムな行動、ポリシーの行動、および次の行動をサンプリングし、logsumexpの項を計算します。

  • ステップ:

  1. TD目標 y = r + γ ( min ⁡ ( Q 1 ′ , Q 2 ′ ) + α ⋅ entropy ) を計算します。
  2. critic_1_lossとcritic_2_loss(平均二乗誤差損失)をそれぞれ計算します。
  3. ランダムな行動(一様分布)を追加でサンプリングし、ポリシー生成の現在および次の行動を取得し、各Qネットワークの出力に対してlogsumexp操作を実行してCQL正則化項を形成します。
  4. 総critic損失 = SAC critic損失 + β × (CQL正則化項の差)。
  5. critic_1とcritic_2をそれぞれ更新します。
  6. actorを更新して、min ⁡ ( Q 1 , Q 2 ) − α log ⁡ π ( a ∣ s ) を最大化します。
  7. αを更新してポリシーエントロピーが目標エントロピーに近づくようにします。
  8. 最後に目標ネットワークをソフト更新します。
  • 入力: transition_dictに含まれる'states', 'actions', 'rewards', 'next_states', 'dones'。
  • 出力: モデルパラメータが更新され、ポリシーとQネットワークが改善されます。

CQLクラスの詳細分析

class ConservativeQLearning:
    ''' CQLアルゴリズム '''
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound,
                 actor_lr, critic_lr, alpha_lr, target_entropy, tau, gamma,
                 device, beta, num_random):
        """
        CQLクラスのコンストラクタを定義し、状態、隠れ、行動の次元、行動の範囲、
        および各最適化器の学習率、目標エントロピー、tau、gamma、
        デバイス、CQL正則化係数beta、ランダムサンプル数num_randomを受け取ります。
        """
        '''ポリシーネットワーク(actor)を作成し、連続行動の出力分布とサンプリングを行います。'''
        self.actor = PolicyNetworkContinuous(state_dim, hidden_dim, action_dim, action_bound).to(device)
        '''2つのQネットワークを作成し、それぞれ(state,action)ペアの価値を評価し、過剰推定のバイアスを減らします。'''
        self.critic_1 = QValueNetworkContinuous(state_dim, hidden_dim, action_dim).to(device)
        self.critic_2 = QValueNetworkContinuous(state_dim, hidden_dim, action_dim).to(device)
        '''目標Qネットワークを作成し、TD目標の計算に使用し、学習を安定化させます。'''
        self.target_critic_1 = QValueNetworkContinuous(state_dim, hidden_dim, action_dim).to(device)
        self.target_critic_2 = QValueNetworkContinuous(state_dim, hidden_dim, action_dim).to(device)
        '''criticネットワークのパラメータを目標ネットワークにコピーし、初期状態で一致させます。'''
        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
        '''ポリシーネットワークにAdam最適化器を割り当て、学習率をactor_lrに設定します。'''
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        '''2つのQネットワークにそれぞれAdam最適化器を作成します。'''
        self.critic_1_optimizer = torch.optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2_optimizer = torch.optim.Adam(self.critic_2.parameters(), lr=critic_lr)
        '''温度パラメータの対数を初期化し、log_alpha = log(0.01) ≈ -4.6052とします。'''
        self.log_alpha = torch.tensor(np.log(0.01), dtype=torch.float)
        '''log_alphaが勾配を計算できるようにし、学習過程で自動的にalpha=exp({log_alpha})が調整されるようにします。'''
        self.log_alpha.requires_grad = True  # alphaの勾配を計算
        '''log_alphaにAdam最適化器を作成し、学習率をalpha_lrに設定します。'''
        self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=alpha_lr)
        '''目標エントロピーを保存し、温度調整に使用します。'''
        self.target_entropy = target_entropy  # 目標エントロピーの大きさ
        '''割引因子gammaを保存します。'''
        self.gamma = gamma
        '''目標ネットワーク更新に使用するソフト更新係数tauを保存します。'''
        self.tau = tau
        '''CQL損失関数の係数betaを保存し、Qネットワークの追加正則化項のバランスを取ります。'''
        self.beta = beta  # CQL損失関数の係数
        '''CQL正則化項の計算でサンプリングされるランダム行動数を保存します。'''
        self.num_random = num_random  # CQLでの行動サンプル数

    def take_action(self, state):
        """ポリシー実行インターフェースを定義し、単一の状態が与えられた場合に対応する行動を出力します。"""
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(device)
        action, log_prob = self.actor(state)
        return [action.item()]

    def soft_update(self, net, target_net):
        """目標ネットワークをソフト更新し、現在のネットワークパラメータで目標ネットワークパラメータを更新します。"""
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)

    def update(self, transition_dict):
        """
        ポリシーとQネットワークの更新プロセスを定義し、環境から収集した経験データを使用してすべてのネットワークパラメータを更新し、同時にCQLの追加正則化項を計算します。
        """
        '''transition_dictからデータを抽出'''
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(device)
        actions = torch.tensor(transition_dict['actions'], dtype=torch.float).view(-1, 1).to(device)
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1).to(device)
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(device)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1).to(device)
        '''報酬を正規化処理し、ここでは倒立振子環境の報酬をシフトとスケーリングして、報酬範囲をより安定させます。'''
        rewards = (rewards + 8.0) / 8.0  # 倒立振子環境の報酬をリシェイプ
        '''すべての次の状態に対してactorを使用して次の行動とその対数確率を取得します。'''
        next_actions, log_prob = self.actor(next_states)
        '''エントロピー項を計算し、エントロピー = - log_probとなります。'''
        entropy = -log_prob
        '''目標Qネットワークを使用して次の状態でのQ値推定を計算します。'''
        q1_value = self.target_critic_1(next_states, next_actions)
        q2_value = self.target_critic_2(next_states, next_actions)
        '''次の時刻の価値推定を計算し、より小さいQ値(ダブルQメカニズム)を選択し、エントロピー正則化項α⋅entropyを追加します。'''
        next_value = torch.min(q1_value, q2_value) + self.log_alpha.exp() * entropy
        '''
        TD目標を計算し、doneが1(終了状態)の場合は割引を適用しません。
        td_target = 現在の即時報酬 + gamma*次の段階の報酬
        '''
        td_target = rewards + self.gamma * next_value * (1 - dones)
        '''
        2つのQネットワークの平均二乗誤差(MSE)損失をそれぞれ計算し、td_targetに対して(勾配が目標に流れないようにdetachします)。
        '''
        critic_1_loss = torch.mean(F.mse_loss(self.critic_1(states, actions), td_target.detach()))
        critic_2_loss = torch.mean(F.mse_loss(self.critic_2(states, actions), td_target.detach()))

        # 上記はSACと同じ、以下のQネットワーク更新はCQLの追加部分
        '''現在のバッチサイズを取得します。'''
        batch_size = states.shape[0]
        '''
        役割:
        - ランダムな一様分布の行動を生成し、形状を(batch_size*num_random, action_dim)に設定し、範囲を[-1,1]に設定します。
        - CQL正則化項の計算に使用され、追加の行動サンプルとして機能します。
        数値の例:
        - batch_size=64, num_random=10, action_dim=1の場合、形状(640,1)のランダムな行動が生成されます。
          例: [[0.23], [-0.45], ...]
        '''
        random_unif_actions = torch.rand([batch_size * self.num_random, actions.shape[-1]], dtype=torch.float).uniform_(-1, 1).to(device)
        '''
        役割:
        一様分布の対数確率密度を計算します。
        - 連続区間[-1,1]では、各次元の密度は1/2なので、対数確率はlog(0.5)となります。
        - action_dim個の次元に対して、合計はlog(0.5^action_dim)となります。
        '''
        random_unif_log_pi = np.log(0.5**next_actions.shape[-1])
        # random_unif_log_pi = np.log(0.5) * next_actions.shape[-1]
        '''
        役割: データセットを拡張
        - statesに新しい次元を追加し、num_random回繰り返し、最後に(batch_size*num_random, state_dim)にreshapeします。
        数値の例:
        - states shape=(64,3)の場合、
          unsqueeze後は(64,1,3)に、
          repeat後は(64,10,3)に、view後は(640,3)になります。
        '''
        tmp_states = states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, states.shape[-1])
        tmp_next_states = next_states.unsqueeze(1).repeat(1, self.num_random, 1).view(-1, next_states.shape[-1])
        '''ポリシーネットワークを使用してtmp_states(繰り返された実際の状態)から行動をサンプリングし、ランダムな現在の行動とその対数確率を取得します。'''
        random_curr_actions, random_curr_log_pi = self.actor(tmp_states)
        '''同様に、tmp_next_statesから次の行動と対応する対数確率をサンプリングします。'''
        random_next_actions, random_next_log_pi = self.actor(tmp_next_states)
        '''criticネットワークを使用してtmp_statesとランダムな一様な行動random_unif_actionsのQ値を計算し、形状を(batch_size, num_random, 1)にreshapeします。'''
        q1_unif = self.critic_1(tmp_states, random_unif_actions).view(-1, self.num_random, 1)
        q2_unif = self.critic_2(tmp_states, random_unif_actions).view(-1, self.num_random, 1)
        '''criticネットワークを使用してtmp_statesとランダムな現在の行動のQ値を計算し、reshapeします。'''
        q1_curr = self.critic_1(tmp_states, random_curr_actions).view(-1, self.num_random, 1)
        q2_curr = self.critic_2(tmp_states, random_curr_actions).view(-1, self.num_random, 1)
        '''criticネットワークを使用してtmp_statesとランダムな次の行動のQ値を計算し、reshapeします。'''
        q1_next = self.critic_1(tmp_states, random_next_actions).view(-1, self.num_random, 1)
        q2_next = self.critic_2(tmp_states, random_next_actions).view(-1, self.num_random, 1)
        '''
        役割:
        - 3つの部分のQ値を連結します:
          1. ランダムな一様サンプリング行動: 対数確率(固定値)を引きます;
          2. ポリシーでサンプリングされた現在の行動: 対応する対数確率を引きます(勾配伝搬を防ぐためにdetachします);
          3. ポリシーでサンプリングされた次の行動: 同様に。
        - 連結次元は第1次元(行動サンプル次元)です。
        数値の例:
        - 各部分の形状が(64,10,1)の場合、連結後のq1_cat形状は(64,30,1)になります。
        Conservative Q-Learning (CQL)では、オフラインデータ外の行動に対してQ関数が過剰に推定されるのを防ぐため、critic損失に正則化項を追加します。正則化項の考え方は「保守的に」Q値を推定し、未見の行動に対してQ値が過剰に高くならないようにすることです。具体的には、CQLの正則化項は次の形式に似ています:
                            Penalty=logE_{a∼μ}[exp(Q(s,a)−logμ(a))]−E_{(s,a)∼D}[Q(s,a)]
        '''
        q1_cat = torch.cat([
            q1_unif - random_unif_log_pi,
            q1_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),
            q1_next - random_next_log_pi.detach().view(-1, self.num_random, 1)
        ], dim=1)
        q2_cat = torch.cat([
            q2_unif - random_unif_log_pi,
            q2_curr - random_curr_log_pi.detach().view(-1, self.num_random, 1),
            q2_next - random_next_log_pi.detach().view(-1, self.num_random, 1)
        ], dim=1)
        '''
        役割:
        - q1_catとq2_catを行動次元でlogsumexp操作を実行し、平均を取ってスカラー値を取得します。
          これはすべてのランダム行動サンプルのソフト最大値を表します。
        - logsumexpは平滑な最大値関数で、計算式は:
                            log∑_iexp(xi)
        '''
        qf1_loss_1 = torch.logsumexp(q1_cat, dim=1).mean()
        qf2_loss_1 = torch.logsumexp(q2_cat, dim=1).mean()
        '''criticネットワークが現在の実際の(states, actions)に対してQ値の平均をそれぞれ計算します。'''
        qf1_loss_2 = self.critic_1(states, actions).mean()
        qf2_loss_2 = self.critic_2(states, actions).mean()
        '''元のSACのcritic損失(平均二乗誤差)とCQL正則化項を加算し、最終的なcritic損失を形成します。'''
        qf1_loss = critic_1_loss + self.beta * (qf1_loss_1 - qf1_loss_2)
        qf2_loss = critic_2_loss + self.beta * (qf2_loss_1 - qf2_loss_2)
        '''
        critic_1の損失に対して逆伝播更新を実行します。
        critic_2の損失に対して逆伝播更新を実行します。
        '''
        self.critic_1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.zero_grad()
        qf2_loss.backward(retain_graph=True)
        self.critic_2_optimizer.step()

        # ポリシーネットワークを更新
        '''現在のポリシーを使用して実際の状態から行動をサンプリングし、対応する対数確率を取得します。
        self.actorはまだ更新されていません'''
        new_actions, log_prob = self.actor(states)
        entropy = -log_prob
        '''
        現在のポリシーによって生成された行動のQ値を評価し、critic_1とcritic_2でそれぞれ計算し、過剰推定を減らすために最小値を取ります。
        self.critic_1とself.critic_2は先ほど更新されたものです
        '''
        q1_value = self.critic_1(states, new_actions)
        q2_value = self.critic_2(states, new_actions)
        '''
        役割:
        - ポリシー損失を計算し、min(𝑄1,𝑄2)−𝛼log𝜋を最大化することを目指します。ここでは負の値を損失として取ります。
        '''
        actor_loss = torch.mean(-self.log_alpha.exp() * entropy - torch.min(q1_value, q2_value))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # alpha値を更新
        '''
        役割:
        - 温度パラメータαの損失を計算し、ポリシーエントロピーが目標エントロピーに近づくようにします。
        - detach()はentropyの勾配を逆伝播させず、alphaのみを更新することを意味します。
        SACのオリジナル論文では:
                            J(α)=E_{a∼π}[−α(logπ(a∣s)+Htarget)]、ここでH=−logπ(a∣s)
        '''
        alpha_loss = torch.mean((entropy - self.target_entropy).detach() * self.log_alpha.exp())
        self.log_alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.log_alpha_optimizer.step()

        self.soft_update(self.critic_1, self.target_critic_1)
        self.soft_update(self.critic_2, self.target_critic_2)


random.seed(0)
np.random.seed(0)

if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
torch.manual_seed(0)

beta = 5.0
num_random = 5
num_epochs = 100
num_trains_per_epoch = 500

agent = ConservativeQLearning(state_dim, hidden_dim, action_dim, action_bound, actor_lr,
            critic_lr, alpha_lr, target_entropy, tau, gamma, device, beta,
            num_random)

return_list = []
for i in range(10):
    with tqdm(total=int(num_epochs / 10), desc='反復 %d' % i) as pbar:
        for i_epoch in range(int(num_epochs / 10)):
            # ここでの環境との相互作用はポリシー評価のみで、最後にプロット用であり、学習には使用されません
            epoch_return = 0
            state = env.reset()
            done = False
            while not done:
                action = agent.take_action(state)
                result = env.step(action)
                if len(result) == 5:
                    next_state, reward, done, truncated, info = result
                    done = done or truncated  # terminatedとtruncatedフラグをマージできます
                else:
                    next_state, reward, done, info = result
                # next_state, reward, done, _ = env.step(action)
                state = next_state
                epoch_return += reward
            return_list.append(epoch_return)

            for _ in range(num_trains_per_epoch):
                b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)
                transition_dict = {
                    'states': b_s,
                    'actions': b_a,
                    'next_states': b_ns,
                    'rewards': b_r,
                    'dones': b_d
                }
                agent.update(transition_dict)

            if (i_epoch + 1) % 10 == 0:
                pbar.set_postfix({
                    'epoch':
                    '%d' % (num_epochs / 10 * i + i_epoch + 1),
                    'return':
                    '%.3f' % np.mean(return_list[-10:])
                })
            pbar.update(1)


epochs_list = list(range(len(return_list)))
plt.plot(epochs_list, return_list)
plt.xlabel('エポック')
plt.ylabel('リターン')
plt.title('{}上でのCQL'.format(env_name))
plt.show()

mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('エピソード')
plt.ylabel('リターン')
plt.title('{}上でのCQL'.format(env_name))
plt.show()

タグ: 強化学習 オフライン強化学習 SAC CQL PyTorch

6月16日 20:06 投稿