XAIについての検証 - SHAP
はじめに
前回に引き続きfashion-mnistデータに対するXAIの検証を行う。
今回はSHAPアルゴリズムについて検証する。
SHAP
- 各特徴量が加減算的に予測に寄与するとする
- ある特徴を使う場合と使わない場合の差から寄与度(SHAP値)を求める
- Gradient を利用してCNNにも適用できる
実装は以下で公開されている。
https://github.com/slundberg/shap
検証
前回学習したモデルをそのまま利用。
pytorchによるSHAP Deep Explainerは以下にチュートリアルがあったのでそれを参考にした。
pytorch mnistチュートリアル
コード (jupyter notebook)
前回と共通 (DataLoader, モデル定義)
import os import numpy as np import pandas as pd import matplotlib.pyplot as plt %matplotlib inline import seaborn as sns import torch from torch.utils.data import DataLoader import torchvision data_root = "data" os.makedirs(data_root, exist_ok=True) IMG_SIZE= 28 batch_size = 104 train_transform = torchvision.transforms.Compose([ torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.RandomRotation(5), torchvision.transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)), torchvision.transforms.ToTensor(), ]) test_transform = torchvision.transforms.ToTensor() train_dataset = torchvision.datasets.FashionMNIST(data_root, train=True, transform=train_transform, target_transform=None, download=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) valid_dataset = torchvision.datasets.FashionMNIST(data_root, train=False, transform=test_transform, target_transform=None, download=True) valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) print("train : ", len(train_dataset)) print("valid : ", len(valid_dataset)) import torch.nn as nn import torch.nn.functional as F class ConvMod(nn.Sequential): def __init__(self, in_channel, out_channel, kernel_size, padding=1, stride=2, norm_fn=None, act_func=None): seqs = [] seqs.append(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride)) if norm_fn is not None: seqs.append(norm_fn(out_channel)) if act_func is not None: seqs.append(act_func()) super().__init__(*seqs) class Model(nn.Module): MOD = ConvMod def __init__(self, num_classes=10, c0=16, num_conv=3, stride=2, norm_fn=None, act_func=nn.ReLU, last_act=None): super().__init__() seqs = [] cin = 1 cout = c0 for n in range(num_conv): seqs.append(self.MOD(cin, cout, 3, padding=1, stride=stride, norm_fn=norm_fn, act_func=act_func)) cin, cout = cout, cout * 2 self.feature = nn.Sequential(*seqs) self.pool = nn.AdaptiveAvgPool2d(1) self.pred = nn.Linear(c0*4, num_classes) if last_act is not None: self.last_act = last_act() else: self.last_act = None def forward(self, x): x = self.feature(x) gap = self.pool(x).squeeze(3).squeeze(2) o = self.pred(gap) if self.last_act is not None: o = self.last_act(o) return o
モデル読み込み
from functools import partial MODEL_PATH = "model.pth" c0 = 32 device = "cpu" valid_df = pd.read_csv("valid.csv", index_col=0) model = Model(num_classes=len(train_dataset.classes), c0=c0, norm_fn=nn.BatchNorm2d, last_act=partial(nn.Softmax, dim=1)).to(device) model.load_state_dict(torch.load(MODEL_PATH)) model.eval() print((valid_df["true"]==valid_df["pred"]).mean()) valid_df.head()
解析・可視化
import shap def plot_figures(shap_numpy, image_numpy, probs, num_vis_rank=3): num_classes = len(shap_numpy) nrows = 1 + num_vis_rank ncols = len(image_numpy) figsize = (ncols * 3, nrows*2.5) fig, axes = plt.subplots(figsize=figsize, ncols=ncols, nrows=nrows) max_val = np.percentile(np.abs(np.concatenate(shap_numpy)), 99.9) for j in range(ncols): axes[0, j].imshow(image_numpy[j][:, :, 0], cmap="gray") axes[0, j].set_axis_off() p = probs[j] inds = np.argsort(p)[::-1] pro = p[inds] for n in range(num_vis_rank): class_id = inds[n] class_name = valid_dataset.classes[class_id] prob_str = "{:.1f}%".format(pro[n]*100) axes[1 + n, j].imshow(shap_numpy[class_id][j][:, :, 0], cmap="bwr", vmin=-max_val, vmax=max_val) axes[1 + n, j].set_axis_off() axes[1 + n, j].set_title(class_name + " : " + prob_str, fontsize=18) fig.tight_layout() return fig def analyze(v_indices, bg_indices): background = torch.cat([valid_dataset[i][0][None] for i in bg_indices]) test_images = torch.cat([valid_dataset[i][0][None] for i in v_indices]) with torch.no_grad(): outputs = model(background) expected_value = outputs.mean(0).cpu().numpy() probs = model(test_images).cpu().numpy() e = shap.DeepExplainer(model, background) shap_values = e.shap_values(test_images) sh_arr = np.array([shap_value.sum(axis=(1, 2, 3)) for shap_value in shap_values]) shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values] test_numpy = np.swapaxes(np.swapaxes(test_images.numpy(), 1, -1), 1, 2) return plot_figures(shap_numpy, test_numpy, probs) fig = analyze(valid_df[valid_df["true"] == 8].sample(10).index, valid_df.sample(200).index) fig.tight_layout() fig.savefig("tmp.jpg")
前回モデルのクラスごとの予測性能
クラス | precision | recall | F-measure |
---|---|---|---|
T-shirt/top | 0.884774 | 0.86 | 0.872211 |
Trouser | 0.993946 | 0.985 | 0.989453 |
Pullover | 0.86724 | 0.908 | 0.887152 |
Dress | 0.90619 | 0.937 | 0.921337 |
Coat | 0.878 | 0.878 | 0.878 |
Sandal | 0.985787 | 0.971 | 0.978338 |
Shirt | 0.804233 | 0.76 | 0.781491 |
Sneaker | 0.95122 | 0.975 | 0.962963 |
Bag | 0.984048 | 0.987 | 0.985522 |
Ankle boot | 0.968938 | 0.967 | 0.967968 |
Bagクラス
最上段は元の画像、 以降は予測結果上位クラス(1, 2, 3位)に対する
SHAP値の分布である。赤い部分は予測に対して正の寄与をしている領域
青い部分は予測に対して負の寄与をしている領域である。
Bagクラスは多くの画像が正しく判定できているため、上から2行目の第1予測
クラスへの反応が大きい。 バッグの左右端の領域や持ち手部分がBag予測に
寄与していることが分かり、妥当な結果と言える。
第2予測以降はほとんど反応していない。
Trouserクラス
Trouserクラスも比較的高精度で予測出来ているクラスである。
Trouserクラスの場合は股下の背景部分に赤い領域が多く、
そこに注目して判定を行っている。この構造は他のクラスにないため
判定基準としては妥当と思われる。
一方で股下構造の見えない 右から3番目の画像は誤判定している
(Dress クラス 85 %, Trouser 13%) 。このように典型的な特徴から
外れた画像に対しては誤判定が起こりやすい。
靴クラス比較
Sandalクラス
Sneakerクラス
Ankle bootクラス
靴クラスを比較すると
- Sandalクラスは隙間部分や紐部分に反応している
- Sneakerクラスは特定箇所への反応が弱い
- 全体の構造を見て判定しているためSHAPでは特徴がでない?
- Ankle bootクラスはつま先に強く反応しているものが多い
Ankle bootはくるぶしを覆う靴なのでくるぶし当たりに反応するほうが
人間の感覚からは自然であろう。 しかし、実際はつま先付近の構造に
強く反応しているため、つま先付近の構造に対して何か他クラスとの
違いを発見したものと思われる。
その特徴が適切なものならば良いのだが、Sneakerクラス画像の右から二番目は
つま先付近に反応してAnkle bootクラスと誤判定しており、 Ankle bootクラスの
特徴としては十分ではないと思われる (ラベルミスの可能性もあるが)。
今回のモデルで偶然このような学習が進んだのか、 現状の学習手法に問題が
あるかは今後の検証課題である。
トップスクラス比較
T-shirt/top クラス
Pullover クラス
Dress クラス
Coat クラス
Shirt クラス
各クラスに対するSHAPについて
- T-shirt/top は肩から脇付近への反応がやや強い
- Pullover は長袖部分への反応が強いように見える
- そこまではっきりはしていない
- Dressクラスは肩から胸付近と腰付近に反応
- Coatクラスは首元への反応がやや強い
- Shirtクラスは首元と体中心付近に反応
現状トップスについてはSHAPによって説得力のある説明は
できないと思う。Dressクラスの特徴的なボディラインや
T-shirt/topクラスの肩まわりなどある程度人間の感覚に近い傾向
も見て取れるが、noisyであまり綺麗に特徴を捉えているとは言えない。
正の寄与と負の寄与が混在しているような領域も多く、
一見してどちらが優勢なのか分かりづらいのもマイナス点である。
まとめ
前回に引き続きfashion-mnistデータに対してXAIの検証を行った。
SHAPはTrouserの股下などわりと細かい特徴も捉えられているが、
正負の寄与が混じった領域の解釈などに難があると感じた。
今後はGrad-CAMなどの滑らかなsaliency map系のXAIと比較していきたい。