Gradient Accumulation と Normalization

はじめに

batch sizeは学習の安定性やモデル性能に大きな影響を与えるパラメータである。
大きなbatch sizeは学習を安定化するが、GPUのメモリを使い果たしてしまう。
GPT31などの近年の大規模モデルは複数のGPUに分散して非常に大きな
batch sizeをとっており、計算リソースの乏しい人では再現することが難しい。
gradient accumulation2はこの問題をある程度解決してくれる可能性があり、
今回はこのgradient accumulationと各種normalizationの相性をMNISTデータに
対して検証した。

gradient accumulation

参考 : https://medium.com/huggingface/training-larger-batches-practical-tips-on-1-gpu-multi-gpu-distributed-setups-ec88c3e51255

gradient accumulationはbatch単位でパラメータの更新を行わずに、
複数個のbatchのパラメータ勾配を積算してから更新することで
実効的なbatch sizeを増やす手法である。
例えば、batch size 4で4回勾配の積算を行ってからパラメータ更新すると、
メモリ使用量はbatch size 4のままだが、batch size 16と同等の性能が期待される。

ただ、この手法ではbatch normalizationのように学習中のバッチ単位の統計量を
利用している手法ではうまく機能しない可能性がある。

検証

MNISTデータに対して各種normalizationに対する
gradient accumulationの効果を検証した。

データ読み込み

import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms 

train_trans = transforms.Compose([transforms.ToTensor()])
valid_trans = transforms.ToTensor()


data_dir = "data"
train_dataset = MNIST(data_dir, train=True, download=True, transform=train_trans)
valid_dataset = MNIST(data_dir, train=False, download=True, transform=valid_trans)

特にデータ水増しはしていない。

ネットワーク

class SimpleModel(nn.Module):
  def __init__(self, norm_fn=None, channels=[32, 64, 128, 256], c0=1, 
               act_func=nn.ReLU, num_classes=10):
    super().__init__()
    prev_channel = c0
    features = []
    for c, channel in enumerate(channels):
      features.append(nn.Conv2d(prev_channel, channel, 3, padding=1))
      if norm_fn is not None:
        features.append(norm_fn(channel))
      features.append(act_func())
      if c < len(channels) -1:
        features.append(nn.MaxPool2d(2))
      prev_channel = channel
    self.features = nn.Sequential(*features)
    self.dense = nn.Linear(prev_channel, num_classes)

  def forward(self, x):
    x = self.features(x)
    x = nn.AdaptiveAvgPool2d(1)(x).squeeze(3).squeeze(2)
    x = self.dense(x)
    return x

skip connectionのない単純なVGG-likeなネットワークを利用。
Global Average Poolingから全結合1層で分類。

検証パラメータ

学習・評価コード

device = "cuda" if torch.cuda.is_available() else "cpu"
import numpy as np
import pandas as pd

from torch.optim import Adam
from tqdm.notebook import tqdm

def evaluate(model, valid_loader, show_progress=False):
  prg = tqdm(valid_loader) if show_progress else valid_loader
  prds = []
  lbls = []  
  for batch in prg:
    imgs, labels = [b.to(device) for b in batch]
    with torch.no_grad():
      o = model(imgs)
      pred = np.argmax(o.cpu().numpy(), axis=1)
      prds.append(pred)
      lbls.append(labels.cpu().numpy())

  prds = np.concatenate(prds)
  lbls = np.concatenate(lbls)
  acc = np.mean(prds==lbls)
  return acc

def get_iter(train_loader):
  while True:
    for batch in train_loader:
      yield batch


def train(model, num_iterations, batch_size, lr0=1e-3, eval_period=300, valid_batch_size=32, num_accum=1):
  model.train()
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  valid_loader = DataLoader(valid_dataset, batch_size=valid_batch_size, shuffle=False)

  criteria = nn.CrossEntropyLoss()

  optimizer = Adam(model.parameters(), lr=lr0)
  optimizer.zero_grad()

  iters = get_iter(train_loader)

  iter = 0
  losses = []
  accs = []
  prg = tqdm(iters, total=num_iterations*num_accum)
  
  loss_factor = 1.0 / max(num_accum, 1)

  i = 0
  for batch in prg:
    i += 1
    imgs, labels = [b.to(device) for b in batch]
    o = model(imgs)
    loss = criteria(o, labels) * loss_factor
    loss.backward()
    if i % num_accum == 0:
      optimizer.step()
      optimizer.zero_grad()
      iter += 1
    else:
      continue

    losses.append(loss.item() / loss_factor)

    mean_loss = np.mean(losses[-100:])
    if iter % 10 == 0:
      prg.set_description("iter {} , mean loss {:.4f}".format(iter, mean_loss))
    if iter % eval_period == 0:
      model.eval()
      vacc = evaluate(model, valid_loader)
      model.train()
      accs.append((iter, mean_loss, vacc))
    if iter >= num_iterations:
      break

  accs_df = pd.DataFrame(accs, columns=["iter", "train_loss", "valid_acc"]).set_index("iter")
  return accs_df

num_iterations = 3600
batches = [1, 1, 4, 16]
num_accums = [1, 16, 4, 1]

iteration数は3600 (accumulationがある場合はaccumulate数倍される)。
batch size 1, 4, 16に対して、accumulation数を 16, 4, 1とすることで実行的な
batch sizeを合わせている。
比較のためbatch size 1, accumulation数1も計算。
また、初期値によるバラつきを考慮して、同じパラメータで3run実行している。

パラメータ検証 コード

from functools import partial
def get_group_norm_func(num_groups):
  def norm_func(num_channel):
    return nn.GroupNorm(num_groups, num_channel)
  return norm_func


norms = {
    "grp4_norm":get_group_norm_func(4),
    "grp16_norm":get_group_norm_func(16),
    "no_norm":None,
    "bn_norm": nn.BatchNorm2d,
    "ins_norm":nn.InstanceNorm2d,
}

results = []
for key, norm_fn in norms.items():
  accs_dfs = []
  for num_accum, batch_size in zip(num_accums, batches):
    for run in range(3):
      model = SimpleModel(norm_fn=norm_fn).to(device)
      accs_df = train(model, num_iterations, batch_size, num_accum=num_accum)
      accs_df["batch_size"] = batch_size
      accs_df["num_accum"] = num_accum
      accs_df["run"] = run + 1
      accs_dfs.append(accs_df)
  result_df = pd.concat(accs_dfs)
  result_df["norm"] = key
  results.append(result_df)

結果

import seaborn as sns
all_results_df = pd.concat(results).reset_index()
all_results_df["bs_accum"] = "bs" + all_results_df["batch_size"].map(str) + "_acum" + all_results_df["num_accum"].map(str)
g = sns.relplot(x="iter", y="valid_acc", data=all_results_df, kind="line", hue="bs_accum", col="norm", col_wrap=3)

f:id:nakamrnk:20200727134051p:plain

青線がgradient accumulationなし, batch size 1の結果である。
ほとんどの場合において、 batch sizeの増加やgradient accumulationにより
学習速度、最終的な性能ともに向上している。

唯一batch normalizaitonのbatch size 1, gradient accumulation=16の場合のみ
性能の改善が見られていない。 これはbatch normalizationが学習時にバッチ内の統計量
(の移動平均)を保持して推論時に利用するため、このgradient accumulationの実装では、
正しく学習できていないためと思われる。

一方で同じbatch normalizationでもbatch size 4
gradient accumulation=4の場合はbatch size 16と同程度の性能が出ている。
今回のMNISTデータの場合はbatch size=4程度あれば正しい統計量を学習できるため
と解釈できる。

gradient accumulationとバッチサイズ増加についてもう少し詳して見てみる

f:id:nakamrnk:20200727135456p:plain

batch normalization以外はgradient accumulationによりbatch size1, 4それぞれが
batch size 16でaccumulationなしの場合と同程度の性能を示している。
batch normalizationさえ使わなければgradient accumulationにより、
実効的なbatch sizeを増やせることが分かった。

2500 iteration以降の平均 valid accuracy

評価

all_results_df = pd.concat(results).reset_index()
all_results_df["bs_accum"] = "bs" + all_results_df["batch_size"].map(str) + "_acum" + all_results_df["num_accum"].map(str)
targ = ((all_results_df["batch_size"] > 1) | (all_results_df["num_accum"]>1)) & (all_results_df["iter"] > 250)
all_results_df = all_results_df[targ]
print(all_results_df.groupby(["norm", "bs_accum"])[ "valid_acc"].mean().unstack().to_markdown())

norm bs16_acum1 bs1_acum16 bs4_acum4
bn_norm 0.978367 0.908319 0.984044
grp16_norm 0.974481 0.971417 0.971078
grp4_norm 0.972542 0.973103 0.9706
ins_norm 0.984742 0.984817 0.984356
no_norm 0.969178 0.969169 0.967678

今回の問題においては、instance normalizationとbatch normalization
(gradient accumulationなし)がほぼ同程度の性能で、
group normalizationはやや劣る結果となった。
group normalizationについては、normalizationなしと比較すると性能が向上
しているため、全く効果がないわけではないと思われるが、  
instance normalizationには劣る結果となった。
(グループパラメータの値が適切でない?学習不足?)

まとめ

今回はgradient accumulationについて簡単な検証を行った。
結果batch normalization以外の場合はgradient accumulationで
メモリサイズを節約して学習できることが確認できた。
normalizationについては、問題依存の可能性はあるが、
instance normalizationあたりから始めればいいのではないかと思う。

参考文献