XAIについての検証 - Anchors

はじめに

画像系Deep LearingにおけるXAI (Explainable AI)のひとつAnchorsを用いて
Fahion-mnistデータに対して学習を行ったモデルの解析を行った。

XAI

Deep Learning モデルはその性能の高さから様々な分野で利用されているが、
処理の多くがNeural Networkにより抽出された特徴量に依存しているため、
人間には理解することが難しい(ブラックボックス化している)1

このAIのブラックボックス化を緩和するためXAIと呼ばれる技術が注目されている。
XAI関連の技術は様々(NNの説明可能モデルへの置き換え、ブラックボックスの中身検査等2)
であるが今回はモデルの出力結果を説明する技術の一つAnchorsを検証する。

学習データ・モデル

学習データはFashion-Mnistを用いた。
これは28x28のファッション画像データに対する10クラス分類である。

学習データは60,000(各クラス6,000)、評価データ10,000(各クラス1000)である。

ソースコード (jupyter notebook)

モジュール

import os

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

データ読み込み

import torch
from torch.utils.data import DataLoader
import torchvision


data_root = "data"
os.makedirs(data_root, exist_ok=True)

IMG_SIZE= 28
batch_size = 32



train_transform =  torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomRotation(5),
        torchvision.transforms.RandomResizedCrop(IMG_SIZE, scale=(0.9, 1.0)),
        torchvision.transforms.ToTensor(),
    ])
test_transform = torchvision.transforms.ToTensor()


train_dataset = torchvision.datasets.FashionMNIST(data_root, train=True, transform=train_transform, target_transform=None, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

valid_dataset = torchvision.datasets.FashionMNIST(data_root, train=False, transform=test_transform, target_transform=None, download=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

print("train : ", len(train_dataset))
print("valid : ", len(valid_dataset))

モデル定義

import torch.nn as nn
import torch.nn.functional as F

class ConvMod(nn.Sequential):
  def __init__(self, in_channel, out_channel, kernel_size, padding=1, stride=2, norm_fn=None, act_func=None):
    seqs = []
    seqs.append(nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride))
    if norm_fn is not None:
      seqs.append(norm_fn(out_channel))
    if act_func is not None:
      seqs.append(act_func())
    super().__init__(*seqs)


class Model(nn.Module):
  
  MOD = ConvMod
  
  
  def __init__(self, num_classes=10, c0=16, num_conv=3, stride=2, norm_fn=None, act_func=nn.ReLU):
    super().__init__()
    seqs = []
    cin = 1
    cout = c0
    for n in range(num_conv):
      seqs.append(self.MOD(cin, cout, 3, padding=1, stride=stride, norm_fn=norm_fn, act_func=act_func))
      cin, cout = cout, cout * 2
    self.feature = nn.Sequential(*seqs)
    
    self.pool = nn.AdaptiveAvgPool2d(1)
    
    self.pred = nn.Linear(c0*4, num_classes)
    
    
    
  def forward(self, x):
    
    x = self.feature(x)
    
    gap = self.pool(x).squeeze(3).squeeze(2)
    
    o = self.pred(gap)
    return o
    
 

学習

from torch.optim import Adam
from tqdm import tqdm as tqdm


def evaluation(model, valid_loader, criteria, device):
  model.eval()
  prg = tqdm(valid_loader)
  prg.set_description("valid")
  results = [[], []]
  losses = []
  for batch in prg:
    imgs, labels = [b.to(device) for b in batch]
    with torch.no_grad():
      preds = model(imgs)
      loss = criteria(preds, labels)
      losses.append(loss.item())

    results[0].append(labels.cpu().numpy())    
    results[1].append(preds.argmax(dim=1).cpu().numpy())

  valid_loss = np.mean(losses)
  result_df = pd.DataFrame(np.array([np.concatenate(col) for col in results]).T, columns=["true", "pred"])
  return valid_loss, result_df


lr0 = 1e-3
num_epochs = 30
c0 = 32 # num of first layer channels 
device = "cpu"
MODEL_PATH = "model.pth"

def accum_func(epoch):
  if epoch < num_epochs * 0.85:
    return 1
  else:
    return 10
  

model = Model(num_classes=len(train_dataset.classes), c0=c0, norm_fn=nn.BatchNorm2d).to(device)

optimizer = Adam(model.parameters(), lr=lr0)
criterion = nn.CrossEntropyLoss()


optimizer.zero_grad()

logs = []
for epoch in range(1, num_epochs + 1):
  model.train()

  prg = tqdm(train_loader)
  prg.set_description("train Epoch {}".format(epoch))

  num_accum = accum_func(epoch)
  loss_factor = 1.0 / num_accum
  it = 0
  count = 0
  hit = 0
  losses = []
  for batch in prg:
    imgs, labels = [b.to(device) for b in batch]
    preds = model(imgs)
    loss = criterion(preds, labels) * loss_factor
    loss.backward()
    # calculate train accuracy
    it += 1
    count += len(labels)
    hit += (preds.argmax(dim=1) == labels).sum().item()
    # gradient accumulation
    if it % num_accum == 0:
      optimizer.step()
      optimizer.zero_grad()
    
    losses.append(loss.item() / loss_factor)
    
  train_acc = hit / count
  train_loss = np.mean(losses)
  
  valid_loss, valid_df = evaluation(model, valid_loader, criterion, device)
  valid_acc = (valid_df["true"]==valid_df["pred"]).mean()
  print("Epoch {} , Loss train / {:.3f} valid / {:.3f}, Accuracy  train /{:.3f} valid {:.3f}".format(epoch, train_loss, valid_loss, train_acc, valid_acc))
  logs.append((epoch, train_loss, valid_loss, train_acc, valid_acc))
    
logs_df = pd.DataFrame(logs, columns=["Epoch", "train_loss", "valid_loss", "train_acc", "valid_acc"]).set_index("Epoch")
        
fig, _ = plt.subplots(figsize=(15, 4), ncols=2)
logs_df[["train_loss", "valid_loss"]].plot(ax=fig.axes[0], title="loss")
logs_df[["train_acc", "valid_acc"]].plot(ax=fig.axes[1], title="accuracy")
    
    
    
torch.save(model.state_dict(), MODEL_PATH)
valid_df.groupby(["true", "pred"]).size().unstack(fill_value=0)

学習結果

学習曲線

f:id:nakamrnk:20201012103658j:plain

confusion matrix
true\pred T-shirt/top Trouser Pullover Dress Coat Sandal Shirt Sneaker Bag Ankle boot
T-shirt/top 860 0 20 19 5 2 89 0 5 0
Trouser 0 985 1 10 0 0 2 0 2 0
Pullover 10 1 908 10 36 0 35 0 0 0
Dress 8 3 9 937 22 0 20 0 1 0
Coat 1 0 52 32 878 0 36 0 1 0
Sandal 0 0 0 0 0 971 0 21 0 8
Shirt 92 2 56 25 57 1 760 0 7 0
Sneaker 0 0 0 0 0 4 0 975 0 21
Bag 1 0 1 1 2 1 3 2 987 2
Ankle boot 0 0 0 0 0 6 0 27 0 967

全体で92%精度の精度だった。
誤判定は

  • T-shirt/top - Shirtクラス間
  • Pullover - Shirtクラス間

などが多い。

クラスごとの評価結果
precision recall F-measure
T-shirt/top 0.884774 0.86 0.872211
Trouser 0.993946 0.985 0.989453
Pullover 0.86724 0.908 0.887152
Dress 0.90619 0.937 0.921337
Coat 0.878 0.878 0.878
Sandal 0.985787 0.971 0.978338
Shirt 0.804233 0.76 0.781491
Sneaker 0.95122 0.975 0.962963
Bag 0.984048 0.987 0.985522
Ankle boot 0.968938 0.967 0.967968

クラスの内訳は

  • トップス - T-shirt/top , Pullover, Dress , Coat, Shirt
  • ボトムス- Trouser
  • 靴 - Sandal, Sneaker, Ankle boot
  • その他 - Bag

であるので、画像全体の形状が独立しているTrouser, Bagクラスは正解率が高い。
一方で全体の構造が似ているトップス関連(私が見ても違いがよく分からないものも多い)
はやや性能が低く、Shirtクラスの正解率が特に低いため原因を説明したい。

Anchors

AnchorsはXAI技術の一つであり、データ点周辺の予測結果がほとんど変わらない
"Anchor"と呼ばれる領域を用いてモデルを説明しようとするアルゴリズムである。
AnchorsLIMEに近い手法である(Authorが同じ)。

LIMEはあるデータ点近くのモデルの挙動を線形近似することでモデルを説明する
アルゴリズムであるが、 Anchorsはデータ点周辺で予測結果があまり変わらない領域と
それに寄与している特徴量を求めることで、 あるデータの予測に欠かせない
特徴量セットを求めるというアルゴリズムである。

superpixels

LIMEAnchorsは画像を教師なしで領域分割するsuperpixelsと呼ばれる技術により
画像を分割し、分割した部分を塗りつぶしてモデルに入力することによる結果の変化から
その領域の重要度を推定している。

このsuperpixelsを行うためのアルゴリズムは複数あり、このアルゴリズム選択がLIME
説明結果に影響を与える3

今回は以下の3つの手法を比較した。

  • quick-shift4
  • slic5
  • 固定グリッド6

Anchors検証

alibi7ライブラリの実装を利用して検証した。

superpixels 準備

from alibi.explainers import AnchorImage
from skimage.segmentation import slic, quickshift, watershed

def slic_segmentation(image):
  return slic(image, n_segments=30)


def superpixel(image, size=(4, 7)):
    segments = np.zeros([image.shape[0], image.shape[1]])
    row_idx, col_idx = np.where(segments == 0)
    for i, j in zip(row_idx, col_idx):
        segments[i, j] = int((image.shape[1]/size[1]) * (i//size[0]) + j//size[1])
    return segments
  
  
def predict_fn(x):
  img = torch.from_numpy(x.astype(np.float32)/255.0).reshape([-1, 1, IMG_SIZE, IMG_SIZE]).to(device)
  with torch.no_grad():
    logits = model(img)
    probs = nn.Softmax(dim=1)(logits).cpu().numpy()
  return probs


v_index = 1
tensor, label = valid_dataset[v_index]
arr = (tensor.cpu().numpy()[0] * 255).astype(np.uint8)
image_shape =arr.shape


explainer_qs = AnchorImage(predict_fn, image_shape, segmentation_fn="quickshift")
explainer_slic = AnchorImage(predict_fn, image_shape, segmentation_fn=slic_segmentation)
explainer_grid = AnchorImage(predict_fn, image_shape, segmentation_fn=superpixel)


expls = {
  "qs":explainer_qs,
  "slic":explainer_slic,
  "grid":explainer_grid
}

解析・可視化

def analyze(v_indices):
  num_images = len(v_indices)
  num_expls = len(expls)
  figsize = (min(num_images * 4, 25), 3 * (1 + num_expls))
  fig, axes = plt.subplots(figsize=figsize, nrows= 1 + num_expls, ncols=num_images)


  for i, v_index in enumerate(v_indices):
    tensor, label = valid_dataset[v_index]
    arr = (tensor.cpu().numpy()[0] * 255).astype(np.uint8)
    class_label = valid_dataset.classes[label]

    axes[0, i].imshow(arr, cmap="gray");
    axes[0, i].set_title(class_label)


    probs = predict_fn(arr)[0]
    pred_index = probs.argmax()
    prob  = probs[pred_index]
    prob_label = "{} : {:.2f}%".format(valid_dataset.classes[pred_index], prob*100)  
    for k, (key, expl) in enumerate(expls.items()):
      explanation = expl.explain(arr.reshape([IMG_SIZE, IMG_SIZE, 1]), threshold=.95, p_sample=.8, seed=0)
      print(i, key, explanation.precision, explanation.coverage)
      axes[1 + k, i].imshow(explanation.anchor[:,:,0], cmap="gray");
      axes[1 + k, i].set_title(prob_label)
  return fig

各クラス、ランダムに10枚サンプリングした評価データに対してAnchorsの計算を行う。

Bagクラス

f:id:nakamrnk:20201012153925j:plain

上図がBagクラスに対するAnchors検証結果である。
最上段が元画像、2,3, 4行目がそれぞれ、"quick-shift", "slic", "固定グリッド"による
super-pixelアルゴリズムを用いたAnchorsの出力結果である。

Bagクラスの場合全体の構造が他のクラスと異なっているため、Anchorsのような
ローカルな置き換えに反応している画像は少ない。
ただ、持ち手部分は他のクラスにはない特徴であるため、反応している画像も存在する。

Trouserクラス

f:id:nakamrnk:20201012154508j:plain

ThrouserクラスもBagクラスと同様比較的予測精度の高いクラスである。
ボトムスはこのクラスだけなので足の部分が重要であることは予測される。
quick-shift以外のアルゴリズムは足の先端付近の細い部分が重要であるとしているので、
妥当な結果に見える。quick shiftはどこにも反応していない画像がほとんどである。

また、どのアルゴリズムも左から5番目の誤判定画像に関しては画像全体に反応している。 このような誤判定画像に対する結果はTruouserのクラス特徴とは乖離すると思われる。

靴クラス比較
Sandal

f:id:nakamrnk:20201012155550j:plain

Sneaker

f:id:nakamrnk:20201012155608j:plain

Ankle boot

f:id:nakamrnk:20201012155625j:plain

靴関連の3クラスに対するAnchorsの結果を比較すると

  • Sandalクラスは紐や細い部分に反応している
  • Sneakerクラスは靴の正面部分や履き口に反応
  • Ankle bootは正面部分に反応
    • Sneakerとの違いは角度?

のようにある程度合理的に見える挙動をしている。 (quick-shiftアルゴリズム以外)

トップス比較

Dressクラス

トップスの中では比較的精度の高いdressクラスに対するAnchorsの結果は下図のようになっている。

f:id:nakamrnk:20201012160958j:plain

固定グリッドアルゴリズムは比較的結果が安定しており、胸や胴まわりに反応しているものが多い。

Pulloverクラス

f:id:nakamrnk:20201012161326j:plain

pulloverクラスは袖の先端や胴、襟元付近に反応しているものが見られる。

T-shirt/top クラス

f:id:nakamrnk:20201012161727j:plain

T-shirt/topクラスは反応している画像自体が少ないが、
反応しているものは首周りや肩、脇あたりに反応している。

Shirtクラス

f:id:nakamrnk:20201012162058j:plain

Shirtクラスは他クラスと比べて性能が低いクラスである。

全体の構造的には半袖はT-shirt, 長袖はPulloverとの判別が難しそうに見えるが
首まわりの構造が異なるのでそれらがはっきりしている画像は正しく判別できている
ように感じる(右から1, 2, 4番目の画像等)。

そもそも素人目にはShirtに見えない画像(左から4, 5番目等)も多いため、
データ自体の難易度が高いため、性能が低いものと思われる。

まとめ

Fashion Mnistデータに対してAnchorsアルゴリズムによるXAIの検証を行った。
比較的妥当に見えるAnchorが出力されているものも多かったが、
superpixelsアルゴリズムへの依存やグローバルな構造を見づらいという
欠点も感じられたので、他のXAI手法と併用したほうがよいと思った。

参考文献