XAIについての検証 - SHAP

はじめに

前回に引き続きfashion-mnistデータに対するXAIの検証を行う。
今回はSHAPアルゴリズムについて検証する。

SHAP

SHAP1はXAIアルゴリズムの一つである。

  • 各特徴量が加減算的に予測に寄与するとする
  • ある特徴を使う場合と使わない場合の差から寄与度(SHAP値)を求める
  • Gradient を利用してCNNにも適用できる

実装は以下で公開されている。
https://github.com/slundberg/shap

検証

前回学習したモデルをそのまま利用。
pytorchによるSHAP Deep Explainerは以下にチュートリアルがあったのでそれを参考にした。
pytorch mnistチュートリアル

コード (jupyter notebook)

前回と共通 (DataLoader, モデル定義)

import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

import torch
from torch.utils.data import DataLoader
import torchvision


data_root = "data"
os.makedirs(data_root, exist_ok=True)

IMG_SIZE= 28
batch_size = 104



train_transform =  torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomRotation(5),
        torchvision.transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)),
        torchvision.transforms.ToTensor(),
    ])
test_transform = torchvision.transforms.ToTensor()


train_dataset = torchvision.datasets.FashionMNIST(data_root, train=True, transform=train_transform, target_transform=None, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = torchvision.datasets.FashionMNIST(data_root, train=False, transform=test_transform, target_transform=None, download=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

print("train : ", len(train_dataset))
print("valid : ", len(valid_dataset))

import torch.nn as nn
import torch.nn.functional as F

class ConvMod(nn.Sequential):
  def __init__(self, in_channel, out_channel, kernel_size, padding=1, stride=2, norm_fn=None, act_func=None):
    seqs = []
    seqs.append(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride))
    if norm_fn is not None:
      seqs.append(norm_fn(out_channel))
    if act_func is not None:
      seqs.append(act_func())
    super().__init__(*seqs)


class Model(nn.Module):
  MOD = ConvMod
  def __init__(self, num_classes=10, c0=16, num_conv=3, stride=2, norm_fn=None, act_func=nn.ReLU, last_act=None):
    super().__init__()
    seqs = []
    cin = 1
    cout = c0
    for n in range(num_conv):
      seqs.append(self.MOD(cin, cout, 3, padding=1, stride=stride, norm_fn=norm_fn, act_func=act_func))
      cin, cout = cout, cout * 2
    self.feature = nn.Sequential(*seqs)
    
    self.pool = nn.AdaptiveAvgPool2d(1)
    
    self.pred = nn.Linear(c0*4, num_classes)
    
    if last_act is not None:
      self.last_act = last_act()
    else:
      self.last_act = None

  def forward(self, x):
    
    x = self.feature(x)
    
    gap = self.pool(x).squeeze(3).squeeze(2)
    
    o = self.pred(gap)
    
    if self.last_act is not None:
      o = self.last_act(o)

    return o

モデル読み込み

from functools import partial
MODEL_PATH = "model.pth"
c0 = 32
device = "cpu"
valid_df = pd.read_csv("valid.csv", index_col=0)


model = Model(num_classes=len(train_dataset.classes), c0=c0, norm_fn=nn.BatchNorm2d, last_act=partial(nn.Softmax, dim=1)).to(device)
model.load_state_dict(torch.load(MODEL_PATH))
model.eval()
print((valid_df["true"]==valid_df["pred"]).mean())
valid_df.head()

解析・可視化

import shap

def plot_figures(shap_numpy, image_numpy, probs, num_vis_rank=3):
  num_classes = len(shap_numpy)
  nrows = 1 + num_vis_rank
  ncols = len(image_numpy)
  figsize = (ncols * 3, nrows*2.5)
  
  fig, axes  = plt.subplots(figsize=figsize, ncols=ncols, nrows=nrows)
  max_val = np.percentile(np.abs(np.concatenate(shap_numpy)), 99.9)
    
  for j in range(ncols):
    axes[0, j].imshow(image_numpy[j][:, :, 0], cmap="gray")
    axes[0, j].set_axis_off()

    p = probs[j]
    inds = np.argsort(p)[::-1]
    pro = p[inds]
    
    for n in range(num_vis_rank):
      class_id = inds[n]
      class_name = valid_dataset.classes[class_id]
      prob_str = "{:.1f}%".format(pro[n]*100)
      axes[1 + n, j].imshow(shap_numpy[class_id][j][:, :, 0], cmap="bwr", vmin=-max_val, vmax=max_val)
      axes[1 + n, j].set_axis_off()
      axes[1 + n, j].set_title(class_name + " : " +  prob_str, fontsize=18)
  fig.tight_layout()
  return fig
  
  
def analyze(v_indices, bg_indices):
  background = torch.cat([valid_dataset[i][0][None] for i in bg_indices])
  test_images = torch.cat([valid_dataset[i][0][None] for i in v_indices])
  with torch.no_grad():
    outputs = model(background)
    expected_value = outputs.mean(0).cpu().numpy()
    probs = model(test_images).cpu().numpy()
  e = shap.DeepExplainer(model, background)
  shap_values = e.shap_values(test_images)
  sh_arr = np.array([shap_value.sum(axis=(1, 2, 3)) for shap_value in shap_values])

  shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
  test_numpy = np.swapaxes(np.swapaxes(test_images.numpy(), 1, -1), 1, 2)
  return plot_figures(shap_numpy, test_numpy, probs)  
  
fig = analyze(valid_df[valid_df["true"] == 8].sample(10).index, valid_df.sample(200).index)
fig.tight_layout()
fig.savefig("tmp.jpg")

前回モデルのクラスごとの予測性能
クラス precision recall F-measure
T-shirt/top 0.884774 0.86 0.872211
Trouser 0.993946 0.985 0.989453
Pullover 0.86724 0.908 0.887152
Dress 0.90619 0.937 0.921337
Coat 0.878 0.878 0.878
Sandal 0.985787 0.971 0.978338
Shirt 0.804233 0.76 0.781491
Sneaker 0.95122 0.975 0.962963
Bag 0.984048 0.987 0.985522
Ankle boot 0.968938 0.967 0.967968

Bagクラス

f:id:nakamrnk:20201013113510j:plain

最上段は元の画像、 以降は予測結果上位クラス(1, 2, 3位)に対する
SHAP値の分布である。赤い部分は予測に対して正の寄与をしている領域
青い部分は予測に対して負の寄与をしている領域である。

Bagクラスは多くの画像が正しく判定できているため、上から2行目の第1予測
クラスへの反応が大きい。 バッグの左右端の領域や持ち手部分がBag予測に
寄与していることが分かり、妥当な結果と言える。
第2予測以降はほとんど反応していない。

Trouserクラス

f:id:nakamrnk:20201013114412j:plain

Trouserクラスも比較的高精度で予測出来ているクラスである。
Trouserクラスの場合は股下の背景部分に赤い領域が多く、
そこに注目して判定を行っている。この構造は他のクラスにないため
判定基準としては妥当と思われる。

一方で股下構造の見えない 右から3番目の画像は誤判定している
(Dress クラス 85 %, Trouser 13%) 。このように典型的な特徴から
外れた画像に対しては誤判定が起こりやすい。

靴クラス比較

Sandalクラス

f:id:nakamrnk:20201013115142j:plain

Sneakerクラス

f:id:nakamrnk:20201013115228j:plain

Ankle bootクラス

f:id:nakamrnk:20201013115307j:plain

靴クラスを比較すると

  • Sandalクラスは隙間部分や紐部分に反応している
  • Sneakerクラスは特定箇所への反応が弱い
    • 全体の構造を見て判定しているためSHAPでは特徴がでない?
  • Ankle bootクラスはつま先に強く反応しているものが多い

Ankle bootはくるぶしを覆う靴なのでくるぶし当たりに反応するほうが
人間の感覚からは自然であろう。 しかし、実際はつま先付近の構造に
強く反応しているため、つま先付近の構造に対して何か他クラスとの
違いを発見したものと思われる。

その特徴が適切なものならば良いのだが、Sneakerクラス画像の右から二番目は
つま先付近に反応してAnkle bootクラスと誤判定しており、 Ankle bootクラスの
特徴としては十分ではないと思われる (ラベルミスの可能性もあるが)。
今回のモデルで偶然このような学習が進んだのか、 現状の学習手法に問題が
あるかは今後の検証課題である。

トップスクラス比較

T-shirt/top クラス

f:id:nakamrnk:20201013122027j:plain

Pullover クラス

f:id:nakamrnk:20201013122039j:plain

Dress クラス

f:id:nakamrnk:20201013122051j:plain

Coat クラス

f:id:nakamrnk:20201013122116j:plain

Shirt クラス

f:id:nakamrnk:20201013122131j:plain

各クラスに対するSHAPについて

  • T-shirt/top は肩から脇付近への反応がやや強い
  • Pullover は長袖部分への反応が強いように見える
    • そこまではっきりはしていない
  • Dressクラスは肩から胸付近と腰付近に反応
  • Coatクラスは首元への反応がやや強い
  • Shirtクラスは首元と体中心付近に反応

現状トップスについてはSHAPによって説得力のある説明は
できないと思う。Dressクラスの特徴的なボディラインや
T-shirt/topクラスの肩まわりなどある程度人間の感覚に近い傾向
も見て取れるが、noisyであまり綺麗に特徴を捉えているとは言えない。
正の寄与と負の寄与が混在しているような領域も多く、
一見してどちらが優勢なのか分かりづらいのもマイナス点である。

まとめ

前回に引き続きfashion-mnistデータに対してXAIの検証を行った。
SHAPはTrouserの股下などわりと細かい特徴も捉えられているが、
正負の寄与が混じった領域の解釈などに難があると感じた。
今後はGrad-CAMなどの滑らかなsaliency map系のXAIと比較していきたい。

参考文献