制御可能な文章生成RAG - 技術概要

はじめに

蓄えられた知識を元に文章生成を行うTransformerモデル
Retrieval-augmented generation (RAG)の技術について調査した。

Deep Learning による文章生成

T5, GPT-3などのDeep Learningを用いたモデル1は一見すると
人間に近い性能の文章を生成できる。
しかしながら、それらのモデルはいくつか問題を抱えている。

  • 知識と文法を一つのモデルで学習
    • 間違った知識で自然な文章が生成される可能性あり (Fake news)
  • 追加学習が難しい
    • 新たな知識が増えた場合のモデル更新が難しい
  • 出力結果の説明が難しい

これに対してRAGはモデルの外にある知識を利用して文章を生成する。
具体的には、外部知識として文章群を用意し、文章を生成するときはその中から
関連する文章を選び、それらを元に文章を出力する。

RAG

RAGはT5と同様に文章を入力として文章を出力するモデルである。
(出力文章をクラスとすれば分類問題にも使える)

RAGは内部的にはDPR (Retriever)とBART (Generator)の2つのモデルを利用している。
transformers2ライブラリに実装済み。
https://huggingface.co/transformers/model_doc/rag.html

概念図

f:id:nakamrnk:20201019093305j:plain

DPRは文章間の関連度を求めるためのネットワークであり、 これにより
知識文章群から入力文章に関連する文章とその特徴量を抜き出す。
BARTは元の入力とDPRが出力した特徴量から文章を生成する。

これにより、出力された文章に対して知識文章群のどの文章が関連するかが
明確となるため、判定の根拠が示しやすい。 また、知識群に不適切なものや古い知識が
含まれていた場合にそれらを取り除くことで、間違った出力が抑制されることが期待される。

RAGの利用

RAGは前述した通り、良い性質を備えたモデルであるが、 入出力文章データを
用意するだけで学習ができるBARTやT5よりもモデル作成に必要なものが多い。

知識文章の用意

RAGの論文では知識文章としてWikipediaの文章を100単語ごとに区切ったものを使用している。
膨大な質の低いデータが良いのか、数は少ないが質の高いデータが良いのかは検証の余地がある。

DPRの学習

DPRは知識群から特徴量を抽出するEpと入力文章から特徴量を抽出する
Eqの2つのエンコーダからなるモデルである。 RAG論文ではRAGの学習中はDPRの
Epは固定してEq部分のみをfine-tuningしているので、RAGを学習するためには
DPRの事前学習モデルが必要となる。

まとめ

RAGの論文を読み、技術概要をまとめた。
学習は大変そうだが、うまく使えば制御のしやすい文章生成モデルとなり、
面白いことができそうだと思った。
とりあえず単純なデータセットを作って学習し、挙動を理解したい。

参考文献

XAIについての検証 - 手法比較

はじめに

前回までいくつかのXAI手法の検証を行ってきた。
今回はそれらの手法を比較するコードを実装し、githubに公開した。
https://github.com/NeverendingNotification/pytorch-xai-analyze

このコードを用いていくつかの状況で各XAIアルゴリズムの挙動を検証した。

比較したアルゴリズム

  • Anchors1
  • SHAP2
  • Grad-CAM3

初期値依存性

Deep Learningの学習結果はネットワーク重み初期値にある程度依存する。
XAIアルゴリズムの初期値依存性を見るために、 乱数SEED以外のパラメータを
固定して数回学習し、 XAIの可視化結果を比較した。

学習曲線 (評価データ)

f:id:nakamrnk:20201017064850j:plain

精度 (評価データ)

class-F run1 run2 run3
T-shirt/top 0.872945 0.867698 0.870871
Trouser 0.987443 0.985915 0.98542
Pullover 0.897 0.891337 0.897654
Dress 0.921944 0.92323 0.924303
Coat 0.869186 0.880193 0.88101
Sandal 0.978894 0.979879 0.980529
Shirt 0.766631 0.765803 0.771222
Sneaker 0.962451 0.961727 0.959334
Bag 0.986028 0.988 0.985522
Ankle boot 0.972348 0.970707 0.965377
mean-F 0.921487 0.921449 0.922124
acc 0.9221 0.9218 0.9225

評価データに対する性能的には初期SEEDごとのブレはほぼ存在しない。

XAI結果比較 (Ankle bootクラス)

f:id:nakamrnk:20201017064713j:plain

4行ずつ一組で(元画像、Anchors, SHAP, Grad-CAM)。
上から順にrun1, run2, run3。

  • Anchorの挙動はrunごとにややぶれているが、靴底の窪んでいる部分と正面部分に反応しているものが多い。
  • SHAPは共通してつま先付近(+かかとも?)に強く反応している。
  • Grad-CAMは共通して靴の正面部分に反応している。

モデル初期値によってXAIアルゴリズムの挙動が大きく変わることはないようだ。

ネットワークのパラメータ数依存

ネットワークの複雑さによってXAIの挙動が変化するか検証した。
変更したパラメータは

  • チャンネル数 (model.feature.c0)
    • (16, 32, 64, 128)
  • 層数 (model.feature.num_layres)
    • (2, 3 ,4)

はデフォルト値。

チャンネル数比較

学習曲線 (評価データ)

f:id:nakamrnk:20201017070633j:plain

精度 (評価データ)
class-F channel_16 channel_32 channel_64 channel_128
T-shirt/top 0.856566 0.872945 0.878073 0.885328
Trouser 0.981891 0.987443 0.988978 0.989442
Pullover 0.873657 0.897 0.898709 0.905698
Dress 0.910366 0.921944 0.924988 0.937406
Coat 0.861098 0.869186 0.886608 0.906126
Sandal 0.974975 0.978894 0.981379 0.987976
Shirt 0.746842 0.766631 0.793354 0.799401
Sneaker 0.957594 0.962451 0.969966 0.96895
Bag 0.982491 0.986028 0.98854 0.989033
Ankle boot 0.965169 0.972348 0.974333 0.972152
mean-F 0.911065 0.921487 0.928493 0.934151
acc 0.9111 0.9221 0.9289 0.9341

チャンネル数は増加するほど性能が向上している。

XAI結果比較 (Ankle bootクラス)

f:id:nakamrnk:20201017072520j:plain

上から順にチャンネル数(16, 32, 64, 128)。
全体の傾向としてはどのアルゴリズムも多くは変わっていない。

層数比較

学習曲線 (評価データ)

f:id:nakamrnk:20201017073148j:plain

精度 (評価データ)
class-F layers_2 layers_3 layers_4
T-shirt/top 0.821782 0.872945 0.889555
Trouser 0.972039 0.987443 0.990964
Pullover 0.82341 0.897 0.907023
Dress 0.856865 0.921944 0.945744
Coat 0.778786 0.869186 0.909
Sandal 0.940759 0.978894 0.986987
Shirt 0.66242 0.766631 0.807771
Sneaker 0.92137 0.962451 0.971852
Bag 0.960716 0.986028 0.987026
Ankle boot 0.942799 0.972348 0.973764
mean-F 0.868095 0.921487 0.936969
acc 0.8687 0.9221 0.9372

層数が増加するほど性能が向上している。

XAI結果比較 (Ankle bootクラス)

f:id:nakamrnk:20201017073457j:plain

上から順に層数 2, 3, 4。

Anchorsの挙動

層数を増やすほど反応箇所が減少している。
Anchorsの反応箇所はその部分を残すと予測結果があまり変わらない
superpixels領域を示しているため、層数が増加し、視野の広い特徴を
獲得するほどに、ローカルなsuperpixelsに対しての依存度が下がっている
ためと思われる。 解釈性という観点ではあまりよくないため、
superpixelsのとり方を考えるなど対策が必要。

SHAPの挙動

層数2の場合、つま先やかかと付近に強く反応している。
層数が増えると特定箇所への反応は弱くなり、正面付近の輪郭
全体に分布するようになる。 Anchorsと同様画像の特定部位への
依存性が落ちているためと思われる。

Grad-CAMの挙動

Grad-CAMは全てのパラメータで2層目の特徴マップを可視化している。

2層モデルの場合可視化している層は最終層であり、 靴の履き口、
正面、つま先など凹んでいる局所構造に反応している。
3層モデルの場合は靴正面のライン上に反応している。

4層モデルの場合はGrad-CAMが消滅しているものがある。
これはGrad-CAMが特徴量マップの重み付け和をとったあとに
ReLUを通すため、得られたマップ全体が負値ならば0となってしまうためである。
(最終層(4層目)を可視化する場合はこのような挙動はしないはず)
潰れていないマップを見るためには最終層を可視化するほうが良いが、
解像度がさらに落ちるため、局所性は失われる。

まとめ

各XAIアルゴリズムの挙動を比較した。
チャンネル数を増やした場合はあまり挙動が変わらないのに、
層数を変えると挙動が大きく変わるのはおもしろいと思った。
ネットワークのアーキテクチャや問題を変えた場合の挙動も理解したい。

参考文献

XAIについての検証 - Grad-CAM

はじめに

前回前々回に引き続きfashion-mnistデータについてXAIの検証を行う。
今回はGrad-CAMについて検証した。

Grad-CAM

Grad-CAMはXAIアルゴリズムのひとつであり、
特定のクラス予測に対する特徴量マップの勾配から計算した重みで
特徴量マップの重み付け和を求めることで、そのクラス予測が
画像のどの部分を元に行われているかを可視化する手法である。

Grad-CAMはXAIだけでなく他の分野でもAttention mapとして利用されている。

  • 人物識別1
  • 異常検知2
  • 弱教師ありSegmentation3

検証

実装はpytorch-gradcamを利用した。

可視化層選択

Grad-CAMにおいてはNeural Netのどの層を可視化するかという自由度
が存在する(Global Average Pooling前の畳み込み層がよく利用される印象)。

今回検証するモデルは畳み込み層3層から構成されており、
活性化関数の前後も含めると6箇所特徴量マップの候補が存在する。
(BN前後を考慮すると9だが、今回はBN後の特徴量マップのみを考える)。

以下はSandal クラス画像における各特徴量マップに対するGrad-CAMである。

f:id:nakamrnk:20201014095432j:plain

一番上の画像は元画像であり、その下に1層目のActivationの前後、
2層目のActivationの前後、 3層目のActivationの前後に対するGrad-CAMを表示している。

1層から順に特徴量マップの解像度が下がり、抽象化されているためGrad-CAMも
ぼやけていく。 1層は入力画像との差異が小さく、3層は解像度が低すぎるため、
今回は2層のactivation前の特徴量マップを解析に利用する。

Bagクラス

f:id:nakamrnk:20201014100353j:plain

最上段は元画像、それ以下は予測クラス上位1, 2, 3位のクラスに対するGrad-CAMである。
Bag予測に対するGrad-CAMの傾向を観察すると、持ち手部分やバッグ上端のたいらな部分に
強く反応している。 AnchorsやSHAPよりノイズも少なく理解しやすいと感じる。

Trouser クラス

f:id:nakamrnk:20201014100820j:plain

Trouserクラスにおいては、股下部分や裾部分に反応しており、 正しい特徴を捉えられている。

靴クラス

Sandalクラス

f:id:nakamrnk:20201014101406j:plain

Sneakerクラス

f:id:nakamrnk:20201014101418j:plain

Ankle bootクラス

f:id:nakamrnk:20201014101432j:plain

  • Sandalクラスは紐や隙間に反応している
    • SHAPと似た傾向
  • Sneakerクラスは靴の正面部分に反応しているものが多い
    • SHAPと同様はっきりしないものも多い
  • Ankle bootは靴正面ラインに反応している

前回のSHAPはAnkle bootクラスのつま先付近に反応していたが、
今回のGrad-CAMはどちらかというと正面ラインに反応しているように見える。
画像解像度を下げて広いコンテクストを扱えるGrad-CAMとピクセル単位の
勾配をみているSHAPでは解釈が異なるようだ。
複数のXAI手法を比較し、違いを見ることでモデルの予測傾向をつかめるかもしれない。

トップスクラス

トップスクラスは2層のActivation後の特徴量マップを採用した。
(Activation前のGrad-CAMが見づらかった)

T-shirt/top

f:id:nakamrnk:20201014104817j:plain

Pullover

f:id:nakamrnk:20201014104841j:plain

Dress

f:id:nakamrnk:20201014104852j:plain

Coat

f:id:nakamrnk:20201014104903j:plain

Shirt

f:id:nakamrnk:20201014104922j:plain

  • T-shirt/topクラスは全体的にぼやけているが、首や胸、袖下などに反応しているものもある
  • Pulloverクラスもぼやけているが、左袖の端に反応しているものが多い
  • Dressクラスは左右のボディラインに反応
  • Coatクラスは首元と下端に反応
    • Coatは袖が胴よりも下まで伸びているものが多いのでそれを特徴と捉えている?
  • Shirtクラスは正面部分に反応

Dress, Coatクラスはある程度人間にも理解できる反応箇所だが、
性能の低いShirtクラスは分かりづらい結果となっている。
他のXAI手法でも人間の感覚に合わなかったため、 モデル自体が
あまり良い特徴を捉えられていない可能性がある。   

まとめ

今回はGrad-CAMについて検証を行った。
見栄え自体はAnchorsやSHAPよりも良いと感じたが、
最終層を特徴量マップとすると解像度が低いため、局所的な特徴を議論したい
場合は入力に近い層を特徴量マップとして採用したり、
他の手法とも比較する必要があると感じた。

参考文献

XAIについての検証 - SHAP

はじめに

前回に引き続きfashion-mnistデータに対するXAIの検証を行う。
今回はSHAPアルゴリズムについて検証する。

SHAP

SHAP1はXAIアルゴリズムの一つである。

  • 各特徴量が加減算的に予測に寄与するとする
  • ある特徴を使う場合と使わない場合の差から寄与度(SHAP値)を求める
  • Gradient を利用してCNNにも適用できる

実装は以下で公開されている。
https://github.com/slundberg/shap

検証

前回学習したモデルをそのまま利用。
pytorchによるSHAP Deep Explainerは以下にチュートリアルがあったのでそれを参考にした。
pytorch mnistチュートリアル

コード (jupyter notebook)

前回と共通 (DataLoader, モデル定義)

import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torchvision


data_root = "data"
os.makedirs(data_root, exist_ok=True)

IMG_SIZE= 28
batch_size = 104



train_transform =  torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomRotation(5),
        torchvision.transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)),
        torchvision.transforms.ToTensor(),
    ])
test_transform = torchvision.transforms.ToTensor()


train_dataset = torchvision.datasets.FashionMNIST(data_root, train=True, transform=train_transform, target_transform=None, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = torchvision.datasets.FashionMNIST(data_root, train=False, transform=test_transform, target_transform=None, download=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

print("train : ", len(train_dataset))
print("valid : ", len(valid_dataset))

import torch.nn as nn
import torch.nn.functional as F

class ConvMod(nn.Sequential):
  def __init__(self, in_channel, out_channel, kernel_size, padding=1, stride=2, norm_fn=None, act_func=None):
    seqs = []
    seqs.append(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride))
    if norm_fn is not None:
      seqs.append(norm_fn(out_channel))
    if act_func is not None:
      seqs.append(act_func())
    super().__init__(*seqs)


class Model(nn.Module):
  MOD = ConvMod
  def __init__(self, num_classes=10, c0=16, num_conv=3, stride=2, norm_fn=None, act_func=nn.ReLU, last_act=None):
    super().__init__()
    seqs = []
    cin = 1
    cout = c0
    for n in range(num_conv):
      seqs.append(self.MOD(cin, cout, 3, padding=1, stride=stride, norm_fn=norm_fn, act_func=act_func))
      cin, cout = cout, cout * 2
    self.feature = nn.Sequential(*seqs)
    
    self.pool = nn.AdaptiveAvgPool2d(1)
    
    self.pred = nn.Linear(c0*4, num_classes)
    
    if last_act is not None:
      self.last_act = last_act()
    else:
      self.last_act = None

  def forward(self, x):
    
    x = self.feature(x)
    
    gap = self.pool(x).squeeze(3).squeeze(2)
    
    o = self.pred(gap)
    
    if self.last_act is not None:
      o = self.last_act(o)

    return o

モデル読み込み

from functools import partial
MODEL_PATH = "model.pth"
c0 = 32
device = "cpu"
valid_df = pd.read_csv("valid.csv", index_col=0)


model = Model(num_classes=len(train_dataset.classes), c0=c0, norm_fn=nn.BatchNorm2d, last_act=partial(nn.Softmax, dim=1)).to(device)
model.load_state_dict(torch.load(MODEL_PATH))
model.eval()
print((valid_df["true"]==valid_df["pred"]).mean())
valid_df.head()

解析・可視化

import shap

def plot_figures(shap_numpy, image_numpy, probs, num_vis_rank=3):
  num_classes = len(shap_numpy)
  nrows = 1 + num_vis_rank
  ncols = len(image_numpy)
  figsize = (ncols * 3, nrows*2.5)
  
  fig, axes  = plt.subplots(figsize=figsize, ncols=ncols, nrows=nrows)
  max_val = np.percentile(np.abs(np.concatenate(shap_numpy)), 99.9)
    
  for j in range(ncols):
    axes[0, j].imshow(image_numpy[j][:, :, 0], cmap="gray")
    axes[0, j].set_axis_off()

    p = probs[j]
    inds = np.argsort(p)[::-1]
    pro = p[inds]
    
    for n in range(num_vis_rank):
      class_id = inds[n]
      class_name = valid_dataset.classes[class_id]
      prob_str = "{:.1f}%".format(pro[n]*100)
      axes[1 + n, j].imshow(shap_numpy[class_id][j][:, :, 0], cmap="bwr", vmin=-max_val, vmax=max_val)
      axes[1 + n, j].set_axis_off()
      axes[1 + n, j].set_title(class_name + " : " +  prob_str, fontsize=18)
  fig.tight_layout()
  return fig
  
  
def analyze(v_indices, bg_indices):
  background = torch.cat([valid_dataset[i][0][None] for i in bg_indices])
  test_images = torch.cat([valid_dataset[i][0][None] for i in v_indices])
  with torch.no_grad():
    outputs = model(background)
    expected_value = outputs.mean(0).cpu().numpy()
    probs = model(test_images).cpu().numpy()
  e = shap.DeepExplainer(model, background)
  shap_values = e.shap_values(test_images)
  sh_arr = np.array([shap_value.sum(axis=(1, 2, 3)) for shap_value in shap_values])

  shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
  test_numpy = np.swapaxes(np.swapaxes(test_images.numpy(), 1, -1), 1, 2)
  return plot_figures(shap_numpy, test_numpy, probs)  
  
fig = analyze(valid_df[valid_df["true"] == 8].sample(10).index, valid_df.sample(200).index)
fig.tight_layout()
fig.savefig("tmp.jpg")

前回モデルのクラスごとの予測性能
クラス precision recall F-measure
T-shirt/top 0.884774 0.86 0.872211
Trouser 0.993946 0.985 0.989453
Pullover 0.86724 0.908 0.887152
Dress 0.90619 0.937 0.921337
Coat 0.878 0.878 0.878
Sandal 0.985787 0.971 0.978338
Shirt 0.804233 0.76 0.781491
Sneaker 0.95122 0.975 0.962963
Bag 0.984048 0.987 0.985522
Ankle boot 0.968938 0.967 0.967968

Bagクラス

f:id:nakamrnk:20201013113510j:plain

最上段は元の画像、 以降は予測結果上位クラス(1, 2, 3位)に対する
SHAP値の分布である。赤い部分は予測に対して正の寄与をしている領域
青い部分は予測に対して負の寄与をしている領域である。

Bagクラスは多くの画像が正しく判定できているため、上から2行目の第1予測
クラスへの反応が大きい。 バッグの左右端の領域や持ち手部分がBag予測に
寄与していることが分かり、妥当な結果と言える。
第2予測以降はほとんど反応していない。

Trouserクラス

f:id:nakamrnk:20201013114412j:plain

Trouserクラスも比較的高精度で予測出来ているクラスである。
Trouserクラスの場合は股下の背景部分に赤い領域が多く、
そこに注目して判定を行っている。この構造は他のクラスにないため
判定基準としては妥当と思われる。

一方で股下構造の見えない 右から3番目の画像は誤判定している
(Dress クラス 85 %, Trouser 13%) 。このように典型的な特徴から
外れた画像に対しては誤判定が起こりやすい。

靴クラス比較

Sandalクラス

f:id:nakamrnk:20201013115142j:plain

Sneakerクラス

f:id:nakamrnk:20201013115228j:plain

Ankle bootクラス

f:id:nakamrnk:20201013115307j:plain

靴クラスを比較すると

  • Sandalクラスは隙間部分や紐部分に反応している
  • Sneakerクラスは特定箇所への反応が弱い
    • 全体の構造を見て判定しているためSHAPでは特徴がでない?
  • Ankle bootクラスはつま先に強く反応しているものが多い

Ankle bootはくるぶしを覆う靴なのでくるぶし当たりに反応するほうが
人間の感覚からは自然であろう。 しかし、実際はつま先付近の構造に
強く反応しているため、つま先付近の構造に対して何か他クラスとの
違いを発見したものと思われる。

その特徴が適切なものならば良いのだが、Sneakerクラス画像の右から二番目は
つま先付近に反応してAnkle bootクラスと誤判定しており、 Ankle bootクラスの
特徴としては十分ではないと思われる (ラベルミスの可能性もあるが)。
今回のモデルで偶然このような学習が進んだのか、 現状の学習手法に問題が
あるかは今後の検証課題である。

トップスクラス比較

T-shirt/top クラス

f:id:nakamrnk:20201013122027j:plain

Pullover クラス

f:id:nakamrnk:20201013122039j:plain

Dress クラス

f:id:nakamrnk:20201013122051j:plain

Coat クラス

f:id:nakamrnk:20201013122116j:plain

Shirt クラス

f:id:nakamrnk:20201013122131j:plain

各クラスに対するSHAPについて

  • T-shirt/top は肩から脇付近への反応がやや強い
  • Pullover は長袖部分への反応が強いように見える
    • そこまではっきりはしていない
  • Dressクラスは肩から胸付近と腰付近に反応
  • Coatクラスは首元への反応がやや強い
  • Shirtクラスは首元と体中心付近に反応

現状トップスについてはSHAPによって説得力のある説明は
できないと思う。Dressクラスの特徴的なボディラインや
T-shirt/topクラスの肩まわりなどある程度人間の感覚に近い傾向
も見て取れるが、noisyであまり綺麗に特徴を捉えているとは言えない。
正の寄与と負の寄与が混在しているような領域も多く、
一見してどちらが優勢なのか分かりづらいのもマイナス点である。

まとめ

前回に引き続きfashion-mnistデータに対してXAIの検証を行った。
SHAPはTrouserの股下などわりと細かい特徴も捉えられているが、
正負の寄与が混じった領域の解釈などに難があると感じた。
今後はGrad-CAMなどの滑らかなsaliency map系のXAIと比較していきたい。

参考文献

XAIについての検証 - Anchors

はじめに

画像系Deep LearingにおけるXAI (Explainable AI)のひとつAnchorsを用いて
Fahion-mnistデータに対して学習を行ったモデルの解析を行った。

XAI

Deep Learning モデルはその性能の高さから様々な分野で利用されているが、
処理の多くがNeural Networkにより抽出された特徴量に依存しているため、
人間には理解することが難しい(ブラックボックス化している)1

このAIのブラックボックス化を緩和するためXAIと呼ばれる技術が注目されている。
XAI関連の技術は様々(NNの説明可能モデルへの置き換え、ブラックボックスの中身検査等2)
であるが今回はモデルの出力結果を説明する技術の一つAnchorsを検証する。

学習データ・モデル

学習データはFashion-Mnistを用いた。
これは28x28のファッション画像データに対する10クラス分類である。

学習データは60,000(各クラス6,000)、評価データ10,000(各クラス1000)である。

ソースコード (jupyter notebook)

モジュール

import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

データ読み込み

import torch
from torch.utils.data import DataLoader
import torchvision


data_root = "data"
os.makedirs(data_root, exist_ok=True)

IMG_SIZE= 28
batch_size = 32



train_transform =  torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomRotation(5),
        torchvision.transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)),
        torchvision.transforms.ToTensor(),
    ])
test_transform = torchvision.transforms.ToTensor()


train_dataset = torchvision.datasets.FashionMNIST(data_root, train=True, transform=train_transform, target_transform=None, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = torchvision.datasets.FashionMNIST(data_root, train=False, transform=test_transform, target_transform=None, download=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

print("train : ", len(train_dataset))
print("valid : ", len(valid_dataset))

モデル定義

import torch.nn as nn
import torch.nn.functional as F

class ConvMod(nn.Sequential):
  def __init__(self, in_channel, out_channel, kernel_size, padding=1, stride=2, norm_fn=None, act_func=None):
    seqs = []
    seqs.append(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride))
    if norm_fn is not None:
      seqs.append(norm_fn(out_channel))
    if act_func is not None:
      seqs.append(act_func())
    super().__init__(*seqs)


class Model(nn.Module):
  
  MOD = ConvMod
  
  
  def __init__(self, num_classes=10, c0=16, num_conv=3, stride=2, norm_fn=None, act_func=nn.ReLU):
    super().__init__()
    seqs = []
    cin = 1
    cout = c0
    for n in range(num_conv):
      seqs.append(self.MOD(cin, cout, 3, padding=1, stride=stride, norm_fn=norm_fn, act_func=act_func))
      cin, cout = cout, cout * 2
    self.feature = nn.Sequential(*seqs)
    
    self.pool = nn.AdaptiveAvgPool2d(1)
    
    self.pred = nn.Linear(c0*4, num_classes)
    
    
    
  def forward(self, x):
    
    x = self.feature(x)
    
    gap = self.pool(x).squeeze(3).squeeze(2)
    
    o = self.pred(gap)
    return o
    
 

学習

from torch.optim import Adam
from tqdm import tqdm as tqdm


def evaluation(model, valid_loader, criteria, device):
  model.eval()
  prg = tqdm(valid_loader)
  prg.set_description("valid")
  results = [[], []]
  losses = []
  for batch in prg:
    imgs, labels = [b.to(device) for b in batch]
    with torch.no_grad():
      preds = model(imgs)
      loss = criteria(preds, labels)
      losses.append(loss.item())

    results[0].append(labels.cpu().numpy())    
    results[1].append(preds.argmax(dim=1).cpu().numpy())

  valid_loss = np.mean(losses)
  result_df = pd.DataFrame(np.array([np.concatenate(col) for col in results]).T, columns=["true", "pred"])
  return valid_loss, result_df


lr0 = 1e-3
num_epochs = 30
c0 = 32 # num of first layer channels 
device = "cpu"
MODEL_PATH = "model.pth"

def accum_func(epoch):
  if epoch < num_epochs * 0.85:
    return 1
  else:
    return 10
  

model = Model(num_classes=len(train_dataset.classes), c0=c0, norm_fn=nn.BatchNorm2d).to(device)

optimizer = Adam(model.parameters(), lr=lr0)
criterion = nn.CrossEntropyLoss()


optimizer.zero_grad()

logs = []
for epoch in range(1, num_epochs + 1):
  model.train()

  prg = tqdm(train_loader)
  prg.set_description("train Epoch {}".format(epoch))

  num_accum = accum_func(epoch)
  loss_factor = 1.0 / num_accum
  it = 0
  count = 0
  hit = 0
  losses = []
  for batch in prg:
    imgs, labels = [b.to(device) for b in batch]
    preds = model(imgs)
    loss = criterion(preds, labels) * loss_factor
    loss.backward()
    # calculate train accuracy
    it += 1
    count += len(labels)
    hit += (preds.argmax(dim=1) == labels).sum().item()
    # gradient accumulation
    if it % num_accum == 0:
      optimizer.step()
      optimizer.zero_grad()
    
    losses.append(loss.item() / loss_factor)
    
  train_acc = hit / count
  train_loss = np.mean(losses)
  
  valid_loss, valid_df = evaluation(model, valid_loader, criterion, device)
  valid_acc = (valid_df["true"]==valid_df["pred"]).mean()
  print("Epoch {} , Loss train / {:.3f} valid / {:.3f}, Accuracy  train /{:.3f} valid {:.3f}".format(epoch, train_loss, valid_loss, train_acc, valid_acc))
  logs.append((epoch, train_loss, valid_loss, train_acc, valid_acc))
    
logs_df = pd.DataFrame(logs, columns=["Epoch", "train_loss", "valid_loss", "train_acc", "valid_acc"]).set_index("Epoch")
        
fig, _ = plt.subplots(figsize=(15, 4), ncols=2)
logs_df[["train_loss", "valid_loss"]].plot(ax=fig.axes[0], title="loss")
logs_df[["train_acc", "valid_acc"]].plot(ax=fig.axes[1], title="accuracy")
    
    
    
torch.save(model.state_dict(), MODEL_PATH)
valid_df.groupby(["true", "pred"]).size().unstack(fill_value=0)

学習結果

学習曲線

f:id:nakamrnk:20201012103658j:plain

confusion matrix
true\pred T-shirt/top Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag Ankle boot
T-shirt/top 860 0 20 19 5 2 89 0 5 0
Trouser 0 985 1 10 0 0 2 0 2 0
Pullover 10 1 908 10 36 0 35 0 0 0
Dress 8 3 9 937 22 0 20 0 1 0
Coat 1 0 52 32 878 0 36 0 1 0
Sandal 0 0 0 0 0 971 0 21 0 8
Shirt 92 2 56 25 57 1 760 0 7 0
Sneaker 0 0 0 0 0 4 0 975 0 21
Bag 1 0 1 1 2 1 3 2 987 2
Ankle boot 0 0 0 0 0 6 0 27 0 967

全体で92%精度の精度だった。
誤判定は

  • T-shirt/top - Shirtクラス間
  • Pullover - Shirtクラス間

などが多い。

クラスごとの評価結果
precision recall F-measure
T-shirt/top 0.884774 0.86 0.872211
Trouser 0.993946 0.985 0.989453
Pullover 0.86724 0.908 0.887152
Dress 0.90619 0.937 0.921337
Coat 0.878 0.878 0.878
Sandal 0.985787 0.971 0.978338
Shirt 0.804233 0.76 0.781491
Sneaker 0.95122 0.975 0.962963
Bag 0.984048 0.987 0.985522
Ankle boot 0.968938 0.967 0.967968

クラスの内訳は

  • トップス - T-shirt/top , Pullover, Dress , Coat, Shirt
  • ボトムス- Trouser
  • 靴 - Sandal, Sneaker, Ankle boot
  • その他 - Bag

であるので、画像全体の形状が独立しているTrouser, Bagクラスは正解率が高い。
一方で全体の構造が似ているトップス関連(私が見ても違いがよく分からないものも多い)
はやや性能が低く、Shirtクラスの正解率が特に低いため原因を説明したい。

Anchors

AnchorsはXAI技術の一つであり、データ点周辺の予測結果がほとんど変わらない
"Anchor"と呼ばれる領域を用いてモデルを説明しようとするアルゴリズムである。
AnchorsLIMEに近い手法である(Authorが同じ)。

LIMEはあるデータ点近くのモデルの挙動を線形近似することでモデルを説明する
アルゴリズムであるが、 Anchorsはデータ点周辺で予測結果があまり変わらない領域と
それに寄与している特徴量を求めることで、 あるデータの予測に欠かせない
特徴量セットを求めるというアルゴリズムである。

superpixels

LIMEAnchorsは画像を教師なしで領域分割するsuperpixelsと呼ばれる技術により
画像を分割し、分割した部分を塗りつぶしてモデルに入力することによる結果の変化から
その領域の重要度を推定している。

このsuperpixelsを行うためのアルゴリズムは複数あり、このアルゴリズム選択がLIME
説明結果に影響を与える3

今回は以下の3つの手法を比較した。

  • quick-shift4
  • slic5
  • 固定グリッド6

Anchors検証

alibi7ライブラリの実装を利用して検証した。

superpixels 準備

from alibi.explainers import AnchorImage
from skimage.segmentation import slic, quickshift, watershed

def slic_segmentation(image):
  return slic(image, n_segments=30)


def superpixel(image, size=(4, 7)):
    segments = np.zeros([image.shape[0], image.shape[1]])
    row_idx, col_idx = np.where(segments == 0)
    for i, j in zip(row_idx, col_idx):
        segments[i, j] = int((image.shape[1]/size[1]) * (i//size[0]) + j//size[1])
    return segments
  
  
def predict_fn(x):
  img = torch.from_numpy(x.astype(np.float32)/255.0).reshape([-1, 1, IMG_SIZE, IMG_SIZE]).to(device)
  with torch.no_grad():
    logits = model(img)
    probs = nn.Softmax(dim=1)(logits).cpu().numpy()
  return probs


v_index = 1
tensor, label = valid_dataset[v_index]
arr = (tensor.cpu().numpy()[0] * 255).astype(np.uint8)
image_shape =arr.shape


explainer_qs = AnchorImage(predict_fn, image_shape, segmentation_fn="quickshift")
explainer_slic = AnchorImage(predict_fn, image_shape, segmentation_fn=slic_segmentation)
explainer_grid = AnchorImage(predict_fn, image_shape, segmentation_fn=superpixel)


expls = {
  "qs":explainer_qs,
  "slic":explainer_slic,
  "grid":explainer_grid
}

解析・可視化

def analyze(v_indices):
  num_images = len(v_indices)
  num_expls = len(expls)
  figsize = (min(num_images * 4, 25), 3 * (1 + num_expls))
  fig, axes = plt.subplots(figsize=figsize, nrows= 1 + num_expls, ncols=num_images)


  for i, v_index in enumerate(v_indices):
    tensor, label = valid_dataset[v_index]
    arr = (tensor.cpu().numpy()[0] * 255).astype(np.uint8)
    class_label = valid_dataset.classes[label]

    axes[0, i].imshow(arr, cmap="gray");
    axes[0, i].set_title(class_label)


    probs = predict_fn(arr)[0]
    pred_index = probs.argmax()
    prob  = probs[pred_index]
    prob_label = "{} : {:.2f}%".format(valid_dataset.classes[pred_index], prob*100)  
    for k, (key, expl) in enumerate(expls.items()):
      explanation = expl.explain(arr.reshape([IMG_SIZE, IMG_SIZE, 1]), threshold=.95, p_sample=.8, seed=0)
      print(i, key, explanation.precision, explanation.coverage)
      axes[1 + k, i].imshow(explanation.anchor[:,:,0], cmap="gray");
      axes[1 + k, i].set_title(prob_label)
  return fig

各クラス、ランダムに10枚サンプリングした評価データに対してAnchorsの計算を行う。

Bagクラス

f:id:nakamrnk:20201012153925j:plain

上図がBagクラスに対するAnchors検証結果である。
最上段が元画像、2,3, 4行目がそれぞれ、"quick-shift", "slic", "固定グリッド"による
super-pixelアルゴリズムを用いたAnchorsの出力結果である。

Bagクラスの場合全体の構造が他のクラスと異なっているため、Anchorsのような
ローカルな置き換えに反応している画像は少ない。
ただ、持ち手部分は他のクラスにはない特徴であるため、反応している画像も存在する。

Trouserクラス

f:id:nakamrnk:20201012154508j:plain

ThrouserクラスもBagクラスと同様比較的予測精度の高いクラスである。
ボトムスはこのクラスだけなので足の部分が重要であることは予測される。
quick-shift以外のアルゴリズムは足の先端付近の細い部分が重要であるとしているので、
妥当な結果に見える。quick shiftはどこにも反応していない画像がほとんどである。

また、どのアルゴリズムも左から5番目の誤判定画像に関しては画像全体に反応している。 このような誤判定画像に対する結果はTruouserのクラス特徴とは乖離すると思われる。

靴クラス比較
Sandal

f:id:nakamrnk:20201012155550j:plain

Sneaker

f:id:nakamrnk:20201012155608j:plain

Ankle boot

f:id:nakamrnk:20201012155625j:plain

靴関連の3クラスに対するAnchorsの結果を比較すると

  • Sandalクラスは紐や細い部分に反応している
  • Sneakerクラスは靴の正面部分や履き口に反応
  • Ankle bootは正面部分に反応
    • Sneakerとの違いは角度?

のようにある程度合理的に見える挙動をしている。 (quick-shiftアルゴリズム以外)

トップス比較

Dressクラス

トップスの中では比較的精度の高いdressクラスに対するAnchorsの結果は下図のようになっている。

f:id:nakamrnk:20201012160958j:plain

固定グリッドアルゴリズムは比較的結果が安定しており、胸や胴まわりに反応しているものが多い。

Pulloverクラス

f:id:nakamrnk:20201012161326j:plain

pulloverクラスは袖の先端や胴、襟元付近に反応しているものが見られる。

T-shirt/top クラス

f:id:nakamrnk:20201012161727j:plain

T-shirt/topクラスは反応している画像自体が少ないが、
反応しているものは首周りや肩、脇あたりに反応している。

Shirtクラス

f:id:nakamrnk:20201012162058j:plain

Shirtクラスは他クラスと比べて性能が低いクラスである。

全体の構造的には半袖はT-shirt, 長袖はPulloverとの判別が難しそうに見えるが
首まわりの構造が異なるのでそれらがはっきりしている画像は正しく判別できている
ように感じる(右から1, 2, 4番目の画像等)。

そもそも素人目にはShirtに見えない画像(左から4, 5番目等)も多いため、
データ自体の難易度が高いため、性能が低いものと思われる。

まとめ

Fashion Mnistデータに対してAnchorsアルゴリズムによるXAIの検証を行った。
比較的妥当に見えるAnchorが出力されているものも多かったが、
superpixelsアルゴリズムへの依存やグローバルな構造を見づらいという
欠点も感じられたので、他のXAI手法と併用したほうがよいと思った。

参考文献

OpenCVjsとtensorflow.jsによるモデル検証アプリ

はじめに

OpenCVjsは画像処理ライブラリであるOpenCVjavascript版。
tensorflow.jsはtensorflowのjavascript版である。
これらを組み合わせて、webブラウザ上で簡単な画像修正と
モデル推論を行うプログラムをgithubに公開した。

https://github.com/NeverendingNotification/opecvjs_tensorflowjs_viewer

概要

近年Deep Learningは様々な分野で利用されるようになっている。
しかし、 pythonによるDeep Learning関連の環境構築は素人には難しい部分もある。
そこで、webブラウザさえあれば動作するDeep Learning環境として
tensorflow.jsは有望である。これにもともとはCの画像処理ライブラリであった
opencvjavascript版であるOpenCVjsを組み合わせることで、
簡単な画像処理を行いながらDeep Learningに気軽に触れられる
プログラムを作ってみた。

アプリの使い方

このライブラリのローカル環境での使い方を以下に述べる。

機能

このアプリの機能は主に2つあり、

  1. OpenCVjsによる簡単な画像処理
    • 現時点では色変換、回転・スケール変換のみ
  2. 1で修正した画像と元画像に対して学習済みモデルによる予測結果比較
画像編集

上記のgithubリポジトリをクローンして、index.htmlを開くと以下のような画面となる。

f:id:nakamrnk:20200918225705j:plain

上側の領域が元画像領域、中央が編集パラメータ領域、 下側の領域が編集後画像領域である。 

レナ画像(http://www.ess.ic.kanagawa-it.ac.jp/app_images_j.htmlよりダウンロード)を例に
アプリの動きを説明する。

  1. 中央のファイルを選択ボタンからローカルのファイル選択
  2. 中央の編集パラメータを適当に変更する
  3. 編集ボタンを押す

この結果アプリの画面は以下のようになる。
f:id:nakamrnk:20200918230324j:plain

このようにOpenCVjsを利用するとウェブブラウザのみで簡単な画像編集ができる。

モデル推論

モデル推論に関してはローカルで行うのにひと手間必要である。
これはセキュリティ上webブラウザがローカルファイルへのアクセスを許可していない (ことが多い)ため、学習済みのモデルにアクセスできないためである。

これを回避するためにいくつか手法が考えられると思うが、chromeを利用しているならば
chromeアプリであるWeb Server for Chromeを利用するのが楽だと思う。

アプリインストール後左上のアプリ項目からWeb Serverを選択すると以下のような画面が
表示されるため、 CHOOSE FOLDERから先ほどのアプリのルートフォルダを指定し、 画面中央のリンク(ここではhttp://127.0.0.1:8887/)からアプリに移動できる。

f:id:nakamrnk:20200918231338j:plain

ここで先ほどの画像編集と同様に画像を読み込み中央にある予測ボタンを押すと
学習済みモデルの予測結果が表示される。
f:id:nakamrnk:20200918232000j:plain

今回の学習済みモデルは手元のデータで適当に学習した、
人間、動物、食べ物の3クラス分類モデルである。

このモデルでLenna画像を予測した結果、
人間8.5%, 食べ物91%と明らかに間違った結果となっている。
これはこの学習済みモデルの学習データのバイアスが原因があると思われる。

今回学習に利用したデータは正面の人間画像が多く、振り向き顔であるLennaのような
姿勢の画像はなかった。また、Lenna画像はやや色合いが強く、これも学習データと
異なるように感じた。

そこで画像編集機能で彩度を-20した結果で予測を行う(下側のパネル)と
予測結果が人間 44 %まで上昇した。

さらに、明度、回転、拡大変換を追加して、学習データの条件に近づけると
人間 59 %、食べ物 40 %となり、かろうじて間違っていない結果を得ることができた。

f:id:nakamrnk:20200918232957j:plain

このように手軽に画像処理を行いながら、Deep Learningモデルの結果を観測することで、
モデルの持つバイアスを理解しやすくなる。

このアプリで軽くテストしてみた限り、今回のモデルは

  • 色相をピンクよりにすると予測確率が上がり、それ以外は下がる
  • 彩度は下げると、明度は上げると、予測確率が上がる
  • 回転は単体で行うと予測確率が下がるが、拡大といい感じに組み合わせると予測確率が上がる

などのモデルの癖を観測することができた。

まとめ

javascriptを使って学習済みモデルの性質をチェックするアプリを開発した。
ブラウザから手軽に扱えるので、モデルの性質把握がしやすい。
現状画像処理機能の種類やアプリデザインがイマイチなので
暇があったら修正したい。

参考文献

PFRLを試してみる - self play

はじめに

前回PFRLを用いてslime volleyballを学習した。
今回は同じ slime volleyballl環境に対して,
複数のagent を用いたself playを試してみる。

self play

対戦型ゲームにおける強化学習は対戦相手となるエージェントに依存する。
前回の学習では、slime volleyballが予め用意してくれているdefaultエージェントに
勝てるように学習を行ったが、 問題によっては初期に対戦相手となる
エージェントが存在しない場合がある。

そのような場合はself playが有効である。 self playは過去の自分自身に
打ち勝てるように学習を行う手法である。

  1. 初期にランダムにエージェントを初期化
  2. 片方のエージェント(A)のみ学習し、もう片方のエージェント(B)のパラメータを固定する
  3. AがBに安定して勝てるようになるまでAを学習する
  4. AのパラメータをBにコピーして2に戻る。

過去の自分を超えるプロセスを何回か繰り返すことで、
外部の情報(初期対戦相手)なしに、 エージェントを学習することができる。

slime volleyballについてのself playは以下のページで検証されている。
https://github.com/hardmaru/slimevolleygym/blob/master/TRAINING.md

検証

検証は前回同様Google Colaboratory上で行った。

ライブラリ

!pip install pfrl
!pip install slimevolleygym
import slimevolleygym
import argparse
import os

import torch
import torch.nn as nn
import numpy as np
from PIL import Image


import gym


import pfrl
from pfrl.q_functions import DiscreteActionValueHead
from pfrl import agents
from pfrl import experiments
from pfrl import explorers
from pfrl import nn as pnn
from pfrl import utils
from pfrl.q_functions import DuelingDQN
from pfrl import replay_buffers

from pfrl.wrappers import atari_wrappers
from pfrl.initializers import init_chainer_default
from pfrl.q_functions import DistributionalDuelingDQN

環境構築

コード

SEED = 0
train_seed = SEED
test_seed = 2 ** 31 - 1 - SEED

class SelfPlayMultiBinaryAsDiscreteAction(gym.ActionWrapper):
    """Transforms MultiBinary action space to Discrete.
    If the action space of a given env is `gym.spaces.MultiBinary(n)`, then
    the action space of the wrapped env will be `gym.spaces.Discrete(2**n)`,
    which covers all the combinations of the original action space.
    Args:
        env (gym.Env): Gym env whose action space is `gym.spaces.MultiBinary`.
    """

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.MultiBinary)
        self.orig_action_space = env.action_space
        self.action_space = gym.spaces.Discrete(2 ** env.action_space.n)

    def action(self, action):
        return [(action >> i) % 2 for i in range(self.orig_action_space.n)]


    def step(self, action, otherAction=None):
      if otherAction is not None:
        otherAction = self.action(otherAction)
      return self.env.step(self.action(action), otherAction=otherAction)

def make_env(test):
  # Use different random seeds for train and test envs
  env_seed = test_seed if test else train_seed
  env = gym.make("SlimeVolley-v0")  
  env.seed(int(env_seed))
  if isinstance(env.action_space, gym.spaces.MultiBinary):
      env = SelfPlayMultiBinaryAsDiscreteAction(env)
  # if args.render:
  #     env = pfrl.wrappers.Render(env)
  return env 

# 初期SEED設定
utils.set_random_seed(SEED)

# 環境設定
env = make_env(test=False)
eval_env = make_env(test=True)
obs = env.observation_space
obs_size = env.observation_space.low.size
n_actions = env.action_space.n
print(obs_size, n_actions)

前回との違いはgymのwrapperをmulti agent用に修正した部分である。
slime volleyball環境のstep関数は複数入力を与えることでmulti agent
に対応するため、それに合わせてコードを修正した。

エージェント

コード

class DistributionalDuelingHead(nn.Module):
    """Head module for defining a distributional dueling network.
    This module expects a (batch_size, in_size)-shaped `torch.Tensor` as input
    and returns `pfrl.action_value.DistributionalDiscreteActionValue`.
    Args:
        in_size (int): Input size.
        n_actions (int): Number of actions.
        n_atoms (int): Number of atoms.
        v_min (float): Minimum value represented by atoms.
        v_max (float): Maximum value represented by atoms.
    """

    def __init__(self, in_size, n_actions, n_atoms, v_min, v_max):
        super().__init__()
        assert in_size % 2 == 0
        self.n_actions = n_actions
        self.n_atoms = n_atoms
        self.register_buffer(
            "z_values", torch.linspace(v_min, v_max, n_atoms, dtype=torch.float)
        )
        self.a_stream = nn.Linear(in_size // 2, n_actions * n_atoms)
        self.v_stream = nn.Linear(in_size // 2, n_atoms)

    def forward(self, h):
        h_a, h_v = torch.chunk(h, 2, dim=1)
        a_logits = self.a_stream(h_a).reshape((-1, self.n_actions, self.n_atoms))
        a_logits = a_logits - a_logits.mean(dim=1, keepdim=True)
        v_logits = self.v_stream(h_v).reshape((-1, 1, self.n_atoms))
        probs = nn.functional.softmax(a_logits + v_logits, dim=2)
        return pfrl.action_value.DistributionalDiscreteActionValue(probs, self.z_values)

def phi(x):
  return np.asarray(x, dtype=np.float32)



def get_rainbow_agent(gamma, gpu, update_interval=1,replay_start_size=2000, target_update_interval=2000,
                      n_atoms=51, v_max=1, v_min=-1, hidden_size=512,
                      noisy_net_sigma=0.1,
                      lr0=1e-3, eps=1.5e-4, betasteps=2e6, num_step_return=3,
                      minibatch_size=32):
  
  # categorical Q-function
  q_func = nn.Sequential(
          nn.Linear(obs_size, hidden_size),
          nn.ReLU(),
          nn.Linear(hidden_size, hidden_size),
          nn.ReLU(),
          DistributionalDuelingHead(hidden_size, n_actions, n_atoms, v_min, v_max),
      )


  pnn.to_factorized_noisy(q_func, sigma_scale=noisy_net_sigma)
  # 探索アルゴリズム
  explorer = explorers.Greedy()

  # 最適化
  opt = torch.optim.Adam(q_func.parameters(), lr=lr0, eps=eps)

  # replay
  rbuf = replay_buffers.PrioritizedReplayBuffer(
            10 ** 6,
            alpha=0.5,
            beta0=0.4,
            betasteps=betasteps,
            num_steps=num_step_return,
            normalize_by_max="memory"
        )

  agent = agents.CategoricalDoubleDQN(
          q_func,
          opt,
          rbuf,
          gpu=gpu,
          gamma=gamma,
          explorer=explorer,
          minibatch_size=minibatch_size,
          replay_start_size=replay_start_size,
          target_update_interval=target_update_interval,
          update_interval=update_interval,
          batch_accumulator="mean",
          phi=phi,
          max_grad_norm=10,
      )
  return agent

エージェントに関しては前回と同様にRainbow Agentを用いている。

評価コード

from contextlib import ExitStack

def evaluate(agent1, agent2=None, n_episodes=30,  num_obs=1, multi_agent=False):
  contexts = [agent1]

  if multi_agent:
    assert agent2 is not None
    contexts.append(agent2)
  env = eval_env

  scores = []
  terminate = False
  timestep = 0
  obses = []
  actions = []
  rewards = []
  dones = []

  with ExitStack() as stack:
    for agent in contexts:
      stack.enter_context(agent.eval_mode())
    reset = True
    while not terminate:
        if reset:
            obs = env.reset()
            obs2 = obs
            done = False
            test_r = 0
            episode_len = 0
            info = {}
  
        a1 = agent1.act(obs)
        if multi_agent:
          a2 = agent2.act(obs2)

        if len(scores) < num_obs:
          obses.append(obs)
        actions.append(a1)

        if multi_agent:
          obs, r, done, info = env.step(a1, a2)
          obs2 = info['otherObs']
        else:
          obs, r, done, info = env.step(a1)

        rewards.append(r)
        dones.append(done)

        test_r += r
        episode_len += 1
        timestep += 1
        reset = done or info.get("needs_reset", False)
        agent1.observe(obs, r, done, reset)
        if multi_agent:
          agent2.observe(obs2, -r, done, reset)     
        if reset:
            # As mixing float and numpy float causes errors in statistics
            # functions, here every score is cast to float.
            scores.append(float(test_r))
        terminate = len(scores) >= n_episodes

  return obses, actions, rewards, dones, scores
    

可視化コード

import cv2
from PIL import Image

def get_converter(img_size, x0=-2, x1=2, y0=0, y1=2):
  def converter(x, y):
    px = int((np.clip(x, x0, x1) - x0)/(x1 - x0) * (img_size - 1))
    py = img_size - int((np.clip(y, y0, y1) - y0)/(y1 - y0) * (img_size - 1))
    return px, py

  return converter


def visualize(obses, rewards, skip=1, img_size=64):
  arrs = []
  converter = get_converter(img_size)
  total_reward = 0

  pb1 = converter(0, 0)
  pb2 = converter(0, 0.2)

  p_a = 0
  p_o = 0
  for s, (o, reward) in enumerate(zip(obses[:-1:skip], rewards)):
    arr = np.zeros([img_size, img_size, 3], dtype=np.uint8) + 255
    pa = converter(-o[0], o[1])
    pb = converter(-o[4], o[5])
    po = converter(o[8], o[9])

    cv2.circle(arr, pa, 3, (255, 0, 0), -1)
    cv2.circle(arr, pb, 3, (0, 255, 0), -1)
    cv2.circle(arr, po, 3, (0, 0, 255), -1)
    cv2.line(arr, pb1, pb2, (128, 64, 64), 2)
    total_reward += reward
    if reward > 0:
      p_a += 1
    elif reward < 0:
      p_o += 1

    txt = "{}:{}:{}".format(p_a, p_o, total_reward)
    cv2.putText(arr, txt, (3, 15), cv2.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 0))

    arrs.append(arr)

  return arrs

学習

パラメータ

steps = 10 ** 6
gamma = 0.98

update_interval = 1 
betasteps = steps / update_interval  * 2
gpu = 0 if torch.cuda.is_available() else -1
print(gpu)

学習

import time
import logging
import pandas as pd
from pfrl.experiments.evaluator import save_agent

agent  = get_rainbow_agent(gamma, gpu, update_interval=update_interval,
                                  betasteps=betasteps) 
agent2 = get_rainbow_agent(gamma, gpu, update_interval=update_interval,
                                  betasteps=betasteps) 


outdir = "selfplay_results"
os.makedirs(outdir, exist_ok=True)


champion_period = 20000
next_champion = champion_period
num_match = 20
update_threshold = 0.5

eval_period = 50000
num_eval = 20
next_eval = eval_period

save_period = 200000


logger = logging.getLogger(__name__)

match_scores = []
eval_scores = []
logs = []
episode_r = 0
episode_idx = 0
episode_len = 0
t = 0
num_champion = 0
obs = env.reset()
obs2 = obs

t0 = time.time()
with agent2.eval_mode():
  while t < steps:
    action = agent.act(obs)
    action2 = agent2.act(obs2)
    obs, r, done, info = env.step(action, action2)
    obs2 = info['otherObs']
    t += 1
    episode_r += r
    episode_len += 1
    reset = info.get("needs_reset", False)
    agent.observe(obs, r, done, reset)
    agent2.observe(obs2, -r, done, reset)
    if done or reset or t == steps:
      if t == steps:
          break

      if t  >= next_champion:
        _, actions, _, _, scores = evaluate(agent, agent2=agent2, n_episodes=num_match,  num_obs=0, multi_agent=True)
        mean_score = np.mean(scores)
        if mean_score > update_threshold:
          agent2.model.load_state_dict(agent.model.state_dict())
          save_agent(agent, t, outdir, logger, suffix="_oldagent")
          num_champion += 1
        print("Match {} :{},  {:.4f} {:.1f} s".format(t, num_champion, mean_score, time.time() - t0))
        next_champion += champion_period
        match_scores.extend([(t, num_champion, s) for s in scores])

      if t  >= next_eval:
        _, actions, _, _, scores = evaluate(agent, agent2=agent2, n_episodes=num_eval,  num_obs=0, multi_agent=False)
        eval_scores.extend([(t, s) for s in scores])
        stats = {}
        stats["steps"] = t
        stats["episodes"] = episode_idx
        stats["time"] = time.time() - t0
        stats["mean"] = np.mean(scores)
        stats["median"] = np.median(scores)
        stats["stdev"] = np.std(scores)
        for k, v in agent.get_statistics():
          stats[k] = v
        print("Eval {} : {} {:.1f} s".format(t, np.mean(scores), time.time() - t0))
        logs.append(stats)
        pd.DataFrame(logs).to_csv(os.path.join(outdir, "scores.csv"))
        next_eval += eval_period

      # Start a new episode
      episode_r = 0
      episode_idx += 1
      episode_len = 0
      obs = env.reset()
      obs2 = obs


    if t % save_period == 0:
      save_agent(agent, t, outdir, logger, suffix="_checkpoint")

以下の条件でself play学習を行った。

  • 20,000 framesごとに新エージェントと旧エージェントを20回対戦させる
  • 新エージェントの旧エージェントに対するスコアが0.5以上の場合は旧エージェントのパラメータを更新する

学習時間の節約のため新エージェントと旧エージェントとの評価用対戦は20回としているが、
性能差を正確に見たい場合は、もう少し多くしたほうが良いかもしれない。
また、対戦周期の20,000 framesも検討の余地があると思う。

結果

学習経過

f:id:nakamrnk:20200810140855p:plain

学習エージェントAの対戦エージェントB(1つ前の世代)に対する平均スコアを
プロットすると上図のようになる。
800,000 framesで計14世代のエージェントが生まれた。
序盤の性能の低い段階では、世代の切り替わりが激しい。
(第三世代はやや長いが...)
10世代以降になると1世代にかかるframe数も伸びており、
ある程度以上成長すると過去の自分に打ち勝つことが難しくなるのが分かる。

対戦相手がslime volleyball defaultエージェントの場合のスコア

f:id:nakamrnk:20200810141340p:plain

学習序盤ではdefaultエージェントに対するスコアはほとんど伸びていない。
350,000 framesほどで-2.5くらいとなり、その後しばらく停滞するが、
最終的には-0.4(700,000 frames)となり、やや負け越すぐらいの性能となる。
self playのみでもdefaultエージェントに近い性能までは学習できることが分かった。

self play対戦成績

世代交代はエージェントがひとつ前の世代に対して勝つと
行われるが、 それ以前の世代にも勝てるかは判定していない。
そのため、前の世代のみに強くてそれ以前の世代には弱いような
汎用性のないエージェントが生じる可能性もある。

ここでは、いくつかの世代間で対戦成績を比較し、そのようなことが
起こっていないかを確認する。

対戦結果

今回は1, 3, 8, 10, 12, 14世代に対して総当りの対戦を行った。
1つの組み合わせあたり、30戦行い平均スコアを求めた。

A\B 1 3 8 10 12 14
1 nan -0.266667 -3.86667 -4.36667 -4.63333 -4.93333
3 0.266667 nan -3.96667 -4.53333 -4.73333 -4.83333
8 3.86667 3.96667 nan -3.93333 -4.2 -4.3
10 4.36667 4.53333 3.93333 nan -2.46667 -2.2
12 4.63333 4.73333 4.2 2.46667 nan -0.533333
14 4.93333 4.83333 4.3 2.2 0.533333 nan

エージェントA(縦軸)とエージェントB(横軸)の対戦結果は上表のようになる。
1, 3世代は8世代以降に大きく負け越しており、 defaultエージェントに対する学習曲線で
見られた通り、 1, 3 世代と8世代以降には大きな性能差があることが分かる。

それ以降の世代でも今回比較した範囲では後の世代に行くほど性能が向上している。 
(12世代と14世代の差は僅かではあるが...)

3世代 vs 8世代 (スコア : -3.97)

f:id:nakamrnk:20200810161026g:plain

左側の赤が3世代、右側の青が8世代のエージェントである。
実際は30ゲーム試行して平均スコアを計算しているが、最初の1ゲームのみ表示。

3世代と8世代を比較すると,3世代はほとんどボールに触れていないが、
8世代になると自分の近くに来たボールには反応できている。

8世代 vs 10世代 (スコア : -3.93)

f:id:nakamrnk:20200810161317g:plain

赤 : 8世代、 青 : 10世代。
8世代と10世代の試合ではだいぶラリーが続くようになっている。
8世代はボールに反応して動けてはいるが、 ボールコントロール
あまく、相手コートにうまく返せていない。 一方で10世代は2, 3タッチで
ボールを相手コートに返せているので、成長が見て取れる。

10世代 vs 14世代 (スコア : -2.2)

f:id:nakamrnk:20200810162024g:plain

赤 : 10世代、 青 : 14世代。

10世代と14世代の対戦ではどちらも中々ボールを落とさない。
14世代のほうが動きが洗練されているように見えるが、 10世代も粘るので
結果的に5点先取までにタイムアップ(3,000 framesでタイムアップ)となり、
スコアが2くらいになるものと思われる。

まとめ

今回は slime volleyballをself playによって学習した。
self playのハイパーパラメータには調整の余地があると思ったが、
self playのみでもslime volleyballのdefault エージェントくらいのモデルは
学習できることが分かった。 学習を早めるために、学習初期はself playを
行い、 終盤はdefault エージェントを使うようにすればいいのではないかと
思った。

参考文献