PFRLを試してみる

はじめに

最近Preferred Networksが公開したpytorchによる強化学習ライブラリ
PFRLの内容を確かめて、openai-gymに実装されている
Pendulum問題を学習させてみた。

PFRL

PFRLはchainerによる強化学習ライブラリchainerrl1の後継ライブラリである。
強化学習を行うために必要な機能や、いくつかのモデルフリー強化学習アルゴリズム
が搭載されている。

強化学習

強化学習教師あり学習などと異なり、データを予め用意しない。
Agentが学習時に環境から状態を受け取って、行動を選択し、
報酬を元にデータを収集し、最適な方策を学習する。

f:id:nakamrnk:20200803080034j:plain

Agent

Agentはひとつの強化学習アルゴリズムに対応し、
環境との相互作用により方策を学ぶ。

PFRLは2020/8/3時点で以下のアルゴリズムが実装されている。

  • DQN
  • Categorical DQN
  • Rainbow
  • IQN
  • DDPG
  • A3C
  • ACER
  • PPO
  • TRPO
  • TD3
  • SAC

環境

強化学習が方策を学ぶため相互作用するための対象。
行動を受け取って、状態を変化させるシミュレータ。
PFRLは主に環境へのインターフェースを提供しており、
シミュレータはopenai-gym2やmujoco3などを利用する必要がある。

検証

今回はopenai-gymのpendulum問題に対してPFRLのSAC4を動かしてみる。

Pendulum問題

トルクを制御して一次元の振り子を立たせる問題。

  • 状態 : 3次元 連続値 (cos θ、 sinθ, 角速度)
  • 行動 : 1次元 連続値 (トルク)
  • 報酬 : 振り子の先端が頂点付近にあれば高くなる
  • 1エピソードは200ステップ

本来一次元の振り子なので、角度θと角速度だけでいいのだが、
角度θが0, 360間で不連続となるので、cosθ、sinθの2成分に分けている。

コード

ライブラリ読み込み

!pip install pfrl
import os
import random

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm as tqdm

import pfrl
import torch
import torch.nn
import gym

環境設定

env = gym.make('Pendulum-v0')
obs_size = env.observation_space.low.size
n_actions = env.action_space.shape[0]

print('observation space:', env.observation_space)
print('action space:', env.action_space)

環境はopenai-gymのPendulum-v0をそのまま使う。

Q-Function, Policy-Function

from torch import distributions as dists


class QFunction(torch.nn.Module):

    def __init__(self, obs_size, n_actions):
        super().__init__()
        self.l1 = torch.nn.Linear(obs_size + n_actions, 50)
        self.l2 = torch.nn.Linear(50, 50)
        self.l3 = torch.nn.Linear(50, 1)

    def forward(self, x):
        state, action = x
        h = torch.cat([state, action], 1)
        h = torch.nn.functional.relu(self.l1(h))
        h = torch.nn.functional.relu(self.l2(h))
        h = self.l3(h)
        return h

    def __init__(self, obs_size, n_actions, log_std_max=3, log_std_min=-15):
        super().__init__()
        self.l1 = torch.nn.Linear(obs_size, 50)
        self.l2 = torch.nn.Linear(50, 50)
        self.mean = torch.nn.Linear(50, n_actions)
        self.log_std = torch.nn.Linear(50, n_actions)
        self.log_std_max = log_std_max
        self.log_std_min = log_std_min
    
    def forward(self, x):
        h = x
        h = torch.nn.functional.relu(self.l1(h))
        h = torch.nn.functional.relu(self.l2(h))
        mean = self.mean(h)
        log_std = self.log_std(h)
        log_std = torch.clamp(log_std, min=self.log_std_min, max=self.log_std_max)
        dist = dists.Normal(mean, log_std.exp())
        return dist



SACは状態と行動値を受け取って行動価値を返すQ関数と
状態を受け取って行動を返す方策関数からなる Actor-Criticである。
方策関数は平均と標準偏差をネットワークで予測し、ガウス分布を返す。
Q関数は状態と行動をconcatしてMLPに通して価値を返す。

SAC Agent

#https://github.com/pfnet/pfrl/blob/master/pfrl/agents/soft_actor_critic.py

from torch.nn import functional as F

class ModSAC(pfrl.agents.SoftActorCritic):
    def _batch_act_train(self, batch_obs):
        if self.burnin_action_func is not None and self.n_policy_updates == 0:
            batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))]
        else:
            # deterministic フラグを追加
            batch_action = self.batch_select_greedy_action(batch_obs, deterministic=self.act_deterministically)
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action  

    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
            self.target_q_func1
        ), pfrl.utils.evaluating(self.target_q_func2):
            next_action_distrib = self.policy(batch_next_state)
            next_actions = next_action_distrib.sample()
            next_log_prob = next_action_distrib.log_prob(next_actions)
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)
            # unsqueeze処理を無視(方策関数側で次元を落とすべき?)
#            entropy_term = self.temperature * next_log_prob[..., None]
            entropy_term = self.temperature * next_log_prob
            assert next_q.shape == entropy_term.shape

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal
            ) * torch.flatten(next_q - entropy_term)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))
        # print("===")
        # print(batch_rewards)
        # print(batch_terminal)
        # print(batch_discount)
        # print(next_q)
        # print(entropy_term)
        
        # print(target_q, predict_q1)
        loss1 = 0.5 * F.mse_loss(target_q, predict_q1)
        loss2 = 0.5 * F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()


    def update_policy_and_temperature(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]

        action_distrib = self.policy(batch_state)
        actions = action_distrib.rsample()
        log_prob = action_distrib.log_prob(actions)
        q1 = self.q_func1((batch_state, actions))
        q2 = self.q_func2((batch_state, actions))
        q = torch.min(q1, q2)

        # unsqueeze処理を無視(方策関数側で次元を落とすべき?)
#        entropy_term = self.temperature * log_prob[..., None]
        entropy_term = self.temperature * log_prob
        assert q.shape == entropy_term.shape
        loss = torch.mean(entropy_term - q)

        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()

        self.n_policy_updates += 1

        if self.entropy_target is not None:
            self.update_temperature(log_prob.detach())

        # Record entropy
        with torch.no_grad():
            try:
                self.entropy_record.extend(
                    action_distrib.entropy().detach().cpu().numpy()
                )
            except NotImplementedError:
                # Record - log p(x) instead
                self.entropy_record.extend(-log_prob.detach().cpu().numpy())

PFRL(pfrl==0.1.0)実装のSACがそのままでは動かなかったので、一部修正している。
(方策関数の出力形式がおかしい?)
(追記 : 公式のexampleに記載あり Independentを介する必要がある
https://github.com/pfnet/pfrl/blob/master/examples/mujoco/reproduction/soft_actor_critic/train_soft_actor_critic.py)

学習・評価


max_episode_len = 200


def evaluate(agent, num_eval):
  trajects = []
  prg = tqdm(range(num_eval))
  with agent.eval_mode():
    for i in prg:
        obs = env.reset()
        R = 0
        t = 0
        while True:
            # Uncomment to watch the behavior in a GUI window
            # env.render()
            action = agent.act(obs)
            obs, r, done, _ = env.step(action)
            R += r
            t += 1
            reset = t == max_episode_len
            trajects.append(list(obs) + [action[0], r, done, i])
#            agent.observe(obs, r, done, reset)
            if done or reset:
                break
  result_df = pd.DataFrame(trajects, columns=["cos", "sin", "dot", "action", "reward", "done", "run"]) 
  result_df["theta"] = np.arctan2(result_df["sin"], result_df["cos"]) * 180.0/np.pi
  mean_return = result_df.groupby(["run"])["reward"].sum().mean()
  result_df["mean_return"] = mean_return
  result_df["action"] = np.clip(result_df["action"], -2, 2)
  return mean_return, result_df


def train(agent, run, stat_period=50, show_period=10, eval_period=50, 
          num_eval=100, n_episodes=300):
  results = []
  Rs = []
  for i in range(1, n_episodes + 1):
      obs = env.reset()
      R = 0  # return (sum of rewards)
      t = 0  # time step
      while True:
          action = agent.act(obs)
          obs, reward, done, _ = env.step(action)
          R += reward
          t += 1
          reset = t == max_episode_len
          agent.observe(obs, reward, done, reset)
          if done or reset:
              break
      Rs.append(R)
      if i % show_period== 0:
          print('run', run, 'episode:', i, 'R:', np.mean(Rs[-50:]))
      if i % stat_period == 0:
          print('statistics:', agent.get_statistics())
      if i % eval_period == 0:
        mean, res_df = evaluate(agent, num_eval)
        print("evaluate :", mean)
        results.append([run, i, np.mean(Rs[-50:]), mean])
  return results, res_df


実行

%%time

n_episodes=500
gamma = 0.99
phi = lambda x: x.astype(numpy.float32, copy=False)
gpu = -1
num_runs = 3

def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

res_dfs = []
results = []
for run in range(num_runs):
  env.seed(run)
  set_seed(run)
  policy_func = PolicyFunc(obs_size, n_actions)
  q_func1 = QFunction(obs_size, n_actions)
  q_func2 = QFunction(obs_size, n_actions)

  poly_optimizer = torch.optim.Adam(policy_func.parameters(), eps=1e-3)
  q1_optimizer = torch.optim.Adam(q_func1.parameters(), eps=1e-3)
  q2_optimizer = torch.optim.Adam(q_func2.parameters(), eps=1e-3)

  replay_buffer = pfrl.replay_buffers.ReplayBuffer(capacity=10 ** 6)
  explorer = pfrl.explorers.AdditiveGaussian(
      0.2, low=env.action_space.low, high=env.action_space.high)

  agent = ModSAC(
      policy_func, q_func1, q_func2,
      poly_optimizer, q1_optimizer, q2_optimizer,
      replay_buffer,
      gamma,
      replay_start_size=500,
      update_interval=1,
      phi=phi,
      gpu=gpu,
  )
  rows, res_df = train(agent, run, n_episodes=n_episodes)
  results.extend(rows)
  res_dfs.append(res_df)
log_df = pd.DataFrame(results, columns=["run", "episode", "train", "eval"])
print('Finished.')

SACはOff-polictyのActor-Criticであり、Q関数、方策関数、Experience-Replay
などを用意する必要がある。 また、SACは2つのQ関数を用いて安定性を高めているので、
初期に2つQ関数を用意する。

ハイパーパラメータは適当に設定している。
強化学習は初期値によって学習速度が大きくブレることがあるので、
今回は3run学習している(本来は10runほどやったほうがいいと思う)。
学習時の探索アルゴリズムはAdditiveGaussianとしている。
Pendulum問題は初期値によって1エピソードの合計報酬(収益)が大きく変わるので、
評価時は100エピソードに対する平均収益を求めている。

学習曲線

train_eval_df = log_df.set_index(["run", "episode"]).stack().reset_index().rename(columns={"level_2":"mode", 0:"return"})
sns.relplot(x="episode", y="return", data=train_eval_df, kind="line", hue="run", col="mode", palette="Set1")

f:id:nakamrnk:20200803143020p:plain

run1は100エピソード程度で収束しているのに対して、
run0 は500エピソードでも収束していない。
探索アルゴリズムやパラメータを調整すればもう少し、
runごとのバラつきは小さくなると思う。

軌跡

results_df = []
for r, r_df  in enumerate(res_dfs):
  r_df = r_df.rename(columns={"run":"episode"})
  r_df["run"] = "run {} mean return {:.2f}".format(r, r_df["mean_return"].values[0])
  results_df.append(r_df)
results_df = pd.concat(results_df)
sns.relplot(x="theta", y="dot", data=results_df[results_df["episode"]<8], col="run", hue="action", palette="RdBu", style="episode")

f:id:nakamrnk:20200803144239p:plain

学習終了時の評価100エピソード中8エピソード分の
状態空間軌跡を描画すると上図のようになる。

横軸は頂点からの振り子の選択の角度θ。 縦軸は角速度。 色はトルクの値。
原点(0, 0)がゴール(振り子の先端が上を向き、速度0)である。

基本的にどの軌跡もゴール付近には辿りつけている。
振り子が下のほう(-180, 180付近)にある場合は
トルクが大きく何回か左右にトルク振りながら加速し、
速度が十分に達すると、左上または右下から原点付近に向う。
ゴール近くになるとトルクの符号が変わり速度を落として
原点に静止させるように制御を行っている。

まとめ

今回はPFRLライブラリを使ってみた。
Pendulum問題に対してSACで学習ができることを
確認していくつかのrunについての結果を比較した。
今後は他のアルゴリズムや他の問題についても検証してみたい。

参考文献