XAIについての検証 - Anchors
はじめに
画像系Deep LearingにおけるXAI (Explainable AI)のひとつAnchorsを用いて
Fahion-mnistデータに対して学習を行ったモデルの解析を行った。
XAI
Deep Learning モデルはその性能の高さから様々な分野で利用されているが、
処理の多くがNeural Networkにより抽出された特徴量に依存しているため、
人間には理解することが難しい(ブラックボックス化している)1。
このAIのブラックボックス化を緩和するためXAIと呼ばれる技術が注目されている。
XAI関連の技術は様々(NNの説明可能モデルへの置き換え、ブラックボックスの中身検査等2)
であるが今回はモデルの出力結果を説明する技術の一つAnchorsを検証する。
学習データ・モデル
学習データはFashion-Mnistを用いた。
これは28x28のファッション画像データに対する10クラス分類である。
学習データは60,000(各クラス6,000)、評価データ10,000(各クラス1000)である。
ソースコード (jupyter notebook)
モジュール
import os import numpy as np import pandas as pd import matplotlib.pyplot as plt %matplotlib inline import seaborn as sns
データ読み込み
import torch from torch.utils.data import DataLoader import torchvision data_root = "data" os.makedirs(data_root, exist_ok=True) IMG_SIZE= 28 batch_size = 32 train_transform = torchvision.transforms.Compose([ torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.RandomRotation(5), torchvision.transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)), torchvision.transforms.ToTensor(), ]) test_transform = torchvision.transforms.ToTensor() train_dataset = torchvision.datasets.FashionMNIST(data_root, train=True, transform=train_transform, target_transform=None, download=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) valid_dataset = torchvision.datasets.FashionMNIST(data_root, train=False, transform=test_transform, target_transform=None, download=True) valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False) print("train : ", len(train_dataset)) print("valid : ", len(valid_dataset))
モデル定義
import torch.nn as nn import torch.nn.functional as F class ConvMod(nn.Sequential): def __init__(self, in_channel, out_channel, kernel_size, padding=1, stride=2, norm_fn=None, act_func=None): seqs = [] seqs.append(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride)) if norm_fn is not None: seqs.append(norm_fn(out_channel)) if act_func is not None: seqs.append(act_func()) super().__init__(*seqs) class Model(nn.Module): MOD = ConvMod def __init__(self, num_classes=10, c0=16, num_conv=3, stride=2, norm_fn=None, act_func=nn.ReLU): super().__init__() seqs = [] cin = 1 cout = c0 for n in range(num_conv): seqs.append(self.MOD(cin, cout, 3, padding=1, stride=stride, norm_fn=norm_fn, act_func=act_func)) cin, cout = cout, cout * 2 self.feature = nn.Sequential(*seqs) self.pool = nn.AdaptiveAvgPool2d(1) self.pred = nn.Linear(c0*4, num_classes) def forward(self, x): x = self.feature(x) gap = self.pool(x).squeeze(3).squeeze(2) o = self.pred(gap) return o
学習
from torch.optim import Adam from tqdm import tqdm as tqdm def evaluation(model, valid_loader, criteria, device): model.eval() prg = tqdm(valid_loader) prg.set_description("valid") results = [[], []] losses = [] for batch in prg: imgs, labels = [b.to(device) for b in batch] with torch.no_grad(): preds = model(imgs) loss = criteria(preds, labels) losses.append(loss.item()) results[0].append(labels.cpu().numpy()) results[1].append(preds.argmax(dim=1).cpu().numpy()) valid_loss = np.mean(losses) result_df = pd.DataFrame(np.array([np.concatenate(col) for col in results]).T, columns=["true", "pred"]) return valid_loss, result_df lr0 = 1e-3 num_epochs = 30 c0 = 32 # num of first layer channels device = "cpu" MODEL_PATH = "model.pth" def accum_func(epoch): if epoch < num_epochs * 0.85: return 1 else: return 10 model = Model(num_classes=len(train_dataset.classes), c0=c0, norm_fn=nn.BatchNorm2d).to(device) optimizer = Adam(model.parameters(), lr=lr0) criterion = nn.CrossEntropyLoss() optimizer.zero_grad() logs = [] for epoch in range(1, num_epochs + 1): model.train() prg = tqdm(train_loader) prg.set_description("train Epoch {}".format(epoch)) num_accum = accum_func(epoch) loss_factor = 1.0 / num_accum it = 0 count = 0 hit = 0 losses = [] for batch in prg: imgs, labels = [b.to(device) for b in batch] preds = model(imgs) loss = criterion(preds, labels) * loss_factor loss.backward() # calculate train accuracy it += 1 count += len(labels) hit += (preds.argmax(dim=1) == labels).sum().item() # gradient accumulation if it % num_accum == 0: optimizer.step() optimizer.zero_grad() losses.append(loss.item() / loss_factor) train_acc = hit / count train_loss = np.mean(losses) valid_loss, valid_df = evaluation(model, valid_loader, criterion, device) valid_acc = (valid_df["true"]==valid_df["pred"]).mean() print("Epoch {} , Loss train / {:.3f} valid / {:.3f}, Accuracy train /{:.3f} valid {:.3f}".format(epoch, train_loss, valid_loss, train_acc, valid_acc)) logs.append((epoch, train_loss, valid_loss, train_acc, valid_acc)) logs_df = pd.DataFrame(logs, columns=["Epoch", "train_loss", "valid_loss", "train_acc", "valid_acc"]).set_index("Epoch") fig, _ = plt.subplots(figsize=(15, 4), ncols=2) logs_df[["train_loss", "valid_loss"]].plot(ax=fig.axes[0], title="loss") logs_df[["train_acc", "valid_acc"]].plot(ax=fig.axes[1], title="accuracy") torch.save(model.state_dict(), MODEL_PATH) valid_df.groupby(["true", "pred"]).size().unstack(fill_value=0)
学習結果
学習曲線
confusion matrix
true\pred | T-shirt/top | Trouser | Pullover | Dress | Coat | Sandal | Shirt | Sneaker | Bag | Ankle boot |
---|---|---|---|---|---|---|---|---|---|---|
T-shirt/top | 860 | 0 | 20 | 19 | 5 | 2 | 89 | 0 | 5 | 0 |
Trouser | 0 | 985 | 1 | 10 | 0 | 0 | 2 | 0 | 2 | 0 |
Pullover | 10 | 1 | 908 | 10 | 36 | 0 | 35 | 0 | 0 | 0 |
Dress | 8 | 3 | 9 | 937 | 22 | 0 | 20 | 0 | 1 | 0 |
Coat | 1 | 0 | 52 | 32 | 878 | 0 | 36 | 0 | 1 | 0 |
Sandal | 0 | 0 | 0 | 0 | 0 | 971 | 0 | 21 | 0 | 8 |
Shirt | 92 | 2 | 56 | 25 | 57 | 1 | 760 | 0 | 7 | 0 |
Sneaker | 0 | 0 | 0 | 0 | 0 | 4 | 0 | 975 | 0 | 21 |
Bag | 1 | 0 | 1 | 1 | 2 | 1 | 3 | 2 | 987 | 2 |
Ankle boot | 0 | 0 | 0 | 0 | 0 | 6 | 0 | 27 | 0 | 967 |
全体で92%精度の精度だった。
誤判定は
- T-shirt/top - Shirtクラス間
- Pullover - Shirtクラス間
などが多い。
クラスごとの評価結果
precision | recall | F-measure | |
---|---|---|---|
T-shirt/top | 0.884774 | 0.86 | 0.872211 |
Trouser | 0.993946 | 0.985 | 0.989453 |
Pullover | 0.86724 | 0.908 | 0.887152 |
Dress | 0.90619 | 0.937 | 0.921337 |
Coat | 0.878 | 0.878 | 0.878 |
Sandal | 0.985787 | 0.971 | 0.978338 |
Shirt | 0.804233 | 0.76 | 0.781491 |
Sneaker | 0.95122 | 0.975 | 0.962963 |
Bag | 0.984048 | 0.987 | 0.985522 |
Ankle boot | 0.968938 | 0.967 | 0.967968 |
クラスの内訳は
- トップス - T-shirt/top , Pullover, Dress , Coat, Shirt
- ボトムス- Trouser
- 靴 - Sandal, Sneaker, Ankle boot
- その他 - Bag
であるので、画像全体の形状が独立しているTrouser, Bagクラスは正解率が高い。
一方で全体の構造が似ているトップス関連(私が見ても違いがよく分からないものも多い)
はやや性能が低く、Shirtクラスの正解率が特に低いため原因を説明したい。
Anchors
AnchorsはXAI技術の一つであり、データ点周辺の予測結果がほとんど変わらない
"Anchor"と呼ばれる領域を用いてモデルを説明しようとするアルゴリズムである。
AnchorsはLIMEに近い手法である(Authorが同じ)。
LIMEはあるデータ点近くのモデルの挙動を線形近似することでモデルを説明する
アルゴリズムであるが、 Anchorsはデータ点周辺で予測結果があまり変わらない領域と
それに寄与している特徴量を求めることで、 あるデータの予測に欠かせない
特徴量セットを求めるというアルゴリズムである。
superpixels
LIMEやAnchorsは画像を教師なしで領域分割するsuperpixelsと呼ばれる技術により
画像を分割し、分割した部分を塗りつぶしてモデルに入力することによる結果の変化から
その領域の重要度を推定している。
このsuperpixelsを行うためのアルゴリズムは複数あり、このアルゴリズム選択がLIMEの
説明結果に影響を与える3。
今回は以下の3つの手法を比較した。
Anchors検証
alibi7ライブラリの実装を利用して検証した。
superpixels 準備
from alibi.explainers import AnchorImage from skimage.segmentation import slic, quickshift, watershed def slic_segmentation(image): return slic(image, n_segments=30) def superpixel(image, size=(4, 7)): segments = np.zeros([image.shape[0], image.shape[1]]) row_idx, col_idx = np.where(segments == 0) for i, j in zip(row_idx, col_idx): segments[i, j] = int((image.shape[1]/size[1]) * (i//size[0]) + j//size[1]) return segments def predict_fn(x): img = torch.from_numpy(x.astype(np.float32)/255.0).reshape([-1, 1, IMG_SIZE, IMG_SIZE]).to(device) with torch.no_grad(): logits = model(img) probs = nn.Softmax(dim=1)(logits).cpu().numpy() return probs v_index = 1 tensor, label = valid_dataset[v_index] arr = (tensor.cpu().numpy()[0] * 255).astype(np.uint8) image_shape =arr.shape explainer_qs = AnchorImage(predict_fn, image_shape, segmentation_fn="quickshift") explainer_slic = AnchorImage(predict_fn, image_shape, segmentation_fn=slic_segmentation) explainer_grid = AnchorImage(predict_fn, image_shape, segmentation_fn=superpixel) expls = { "qs":explainer_qs, "slic":explainer_slic, "grid":explainer_grid }
解析・可視化
def analyze(v_indices): num_images = len(v_indices) num_expls = len(expls) figsize = (min(num_images * 4, 25), 3 * (1 + num_expls)) fig, axes = plt.subplots(figsize=figsize, nrows= 1 + num_expls, ncols=num_images) for i, v_index in enumerate(v_indices): tensor, label = valid_dataset[v_index] arr = (tensor.cpu().numpy()[0] * 255).astype(np.uint8) class_label = valid_dataset.classes[label] axes[0, i].imshow(arr, cmap="gray"); axes[0, i].set_title(class_label) probs = predict_fn(arr)[0] pred_index = probs.argmax() prob = probs[pred_index] prob_label = "{} : {:.2f}%".format(valid_dataset.classes[pred_index], prob*100) for k, (key, expl) in enumerate(expls.items()): explanation = expl.explain(arr.reshape([IMG_SIZE, IMG_SIZE, 1]), threshold=.95, p_sample=.8, seed=0) print(i, key, explanation.precision, explanation.coverage) axes[1 + k, i].imshow(explanation.anchor[:,:,0], cmap="gray"); axes[1 + k, i].set_title(prob_label) return fig
各クラス、ランダムに10枚サンプリングした評価データに対してAnchorsの計算を行う。
Bagクラス
上図がBagクラスに対するAnchors検証結果である。
最上段が元画像、2,3, 4行目がそれぞれ、"quick-shift", "slic", "固定グリッド"による
super-pixelアルゴリズムを用いたAnchorsの出力結果である。
Bagクラスの場合全体の構造が他のクラスと異なっているため、Anchorsのような
ローカルな置き換えに反応している画像は少ない。
ただ、持ち手部分は他のクラスにはない特徴であるため、反応している画像も存在する。
Trouserクラス
ThrouserクラスもBagクラスと同様比較的予測精度の高いクラスである。
ボトムスはこのクラスだけなので足の部分が重要であることは予測される。
quick-shift以外のアルゴリズムは足の先端付近の細い部分が重要であるとしているので、
妥当な結果に見える。quick shiftはどこにも反応していない画像がほとんどである。
また、どのアルゴリズムも左から5番目の誤判定画像に関しては画像全体に反応している。 このような誤判定画像に対する結果はTruouserのクラス特徴とは乖離すると思われる。
靴クラス比較
Sandal
Sneaker
Ankle boot
靴関連の3クラスに対するAnchorsの結果を比較すると
- Sandalクラスは紐や細い部分に反応している
- Sneakerクラスは靴の正面部分や履き口に反応
- Ankle bootは正面部分に反応
- Sneakerとの違いは角度?
のようにある程度合理的に見える挙動をしている。 (quick-shiftアルゴリズム以外)
トップス比較
Dressクラス
トップスの中では比較的精度の高いdressクラスに対するAnchorsの結果は下図のようになっている。
固定グリッドアルゴリズムは比較的結果が安定しており、胸や胴まわりに反応しているものが多い。
Pulloverクラス
pulloverクラスは袖の先端や胴、襟元付近に反応しているものが見られる。
T-shirt/top クラス
T-shirt/topクラスは反応している画像自体が少ないが、
反応しているものは首周りや肩、脇あたりに反応している。
Shirtクラス
Shirtクラスは他クラスと比べて性能が低いクラスである。
全体の構造的には半袖はT-shirt, 長袖はPulloverとの判別が難しそうに見えるが
首まわりの構造が異なるのでそれらがはっきりしている画像は正しく判別できている
ように感じる(右から1, 2, 4番目の画像等)。
そもそも素人目にはShirtに見えない画像(左から4, 5番目等)も多いため、
データ自体の難易度が高いため、性能が低いものと思われる。
まとめ
Fashion Mnistデータに対してAnchorsアルゴリズムによるXAIの検証を行った。
比較的妥当に見えるAnchorが出力されているものも多かったが、
superpixelsアルゴリズムへの依存やグローバルな構造を見づらいという
欠点も感じられたので、他のXAI手法と併用したほうがよいと思った。
参考文献
-
https://www.acceluniverse.com/blog/developers/2020/04/XAI-Explainable-AI.html↩
-
https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.quickshift↩
-
https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic↩
-
https://docs.seldon.io/projects/alibi/en/stable/examples/anchor_image_fashion_mnist.html↩
-
https://docs.seldon.io/projects/alibi/en/stable/overview/getting_started.html↩