日本語CTRLを1から学習する - 4

はじめに

前回まで日本語データの収集や前処理について検証してきた。
今回はCTRLの学習について検討した。

CTRL

CTRLは Keskar et al. 2019で提案された制御つきの言語生成モデルである。
英語に対して学習済みのCTRLによる推論の検証は以前の記事で行った。

元論文の実装: https://github.com/salesforce/ctrl
transformersのAPI: https://huggingface.co/transformers/model_doc/ctrl.html

元のCTRLは140GB文の文章データを元に学習した巨大なモデルだが、
今回の検証では小規模なデータしかないため もう少し小さなモデルを学習する。

学習

transformersのサンプルスクリプトを参考にした。
https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py

コード - モジュール import

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sentencepiece as spm
import transformers
from transformers import AutoConfig
from transformers import AutoModelWithLMHead
from transformers import CTRLLMHeadModel

データ前処理

前回利用した青空文庫データを用いる。
前回学習したSentencePieceを利用したトークナイザーを利用。
青空文庫のデータのメタ情報を捨てているので、
学習する制御コードは以下の通りトークン数のみで決定している。

制御コード 内容
20トークン以内の短い文
80トークン以内の中程度の長さの文
80トークンよりも長い文

コード - 前処理

aozora_df = pd.read_csv("aozora.csv", index_col=0)
print(aozora_df.shape)
aozora_df.head()

model_file = "sp_model/test_model_016000.model"
sp_test = spm.SentencePieceProcessor()
sp_test.load(model_file)
all_ids = []
length_ids = []
token_length_limit = 256
for document in tqdm(aozora_df["text"].values):
  if isinstance(document, float):
    continue
  texts = document.split("\n")
  for text in texts:
    ids = sp_test.EncodeAsIds(text)
    if len(ids) <=token_length_limit - 3:
      all_ids.append(ids)
      length_ids.append(len(ids))


PAD_ID = 2
SEQ_LEN = token_length_limit

def label_map(ids):
  len_ids = len(ids)
  if len_ids <=20:
    return 2028
  elif len_ids <= 80:
    return 123
  else:
    return 150

class AozoraDataset(Dataset):
  def __init__(self, ids):
    self.ids = ids
    self.length = len(self.ids)
    self.pad_lengths = [SEQ_LEN - len(i) - 1  for i in self.ids]
    self.labels = [label_map(i) for i in self.ids]

  def __len__(self):
    return self.length

  def __getitem__(self, idx):
    return torch.Tensor([self.labels[idx]] + self.ids[idx] + [PAD_ID] * self.pad_lengths[idx]  ).long()



aozora_dataset = AozoraDataset(all_ids)
aozora_dataloader = DataLoader(aozora_dataset, batch_size=16, shuffle=True)

学習パラメータ

計算をColabで行ったためパラメータをかなり減らしている。

device = "cuda" if torch.cuda.is_available() else "cpu"

config = AutoConfig.from_pretrained("ctrl")
config.n_embd = 384
config.vocab_size = 15000
config.n_layer = 6

model = CTRLLMHeadModel(config).to(device)
optimizer = torch.optim.Adam(model.parameters())

コード - 学習

prg = tqdm(aozora_dataloader)
losses = []
loss_fct = nn.CrossEntropyLoss(ignore_index=PAD_ID)

for batch in prg:
  batch = batch.to(device)
  input_ids  = batch
  optimizer.zero_grad()
  lm_logits, _ = model(input_ids)
  shift_logits = lm_logits[..., :-1, :].contiguous()
  shift_labels = input_ids[..., 1:].contiguous()
  loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  loss.backward()
  losses.append(loss.item())
  optimizer.step()
  if len(losses) > 10000:
    break

結果

10000Iter(1時間半ほど。 1epochは 180000Iterなので途中で終了している)学習した結果。

コード - 推論

max_length = 50
codes = ["短", "中", "長"]
prompt_text = "今日は"
results = []
for code in codes:
  code_id = sp_test.piece_to_id(code)
  encoded = torch.Tensor( [[code_id] + sp_test.encode_as_ids(prompt_text)]).long().to(device)
  generated = model.generate(encoded, max_length=max_length)
  gen_list=list(map(int, generated.cpu().numpy()[0]))
  out = sp_test.decode_ids(gen_list)
  results.append((code, out[1:]))
result_df = pd.DataFrame(results, columns=["制御コード", "結果"]).set_index("制御コード")
print(result_df.to_markdown())                         

初期文 "今日は"
制御コード 結果
今日は、もう、お天気で、お天気が、お天気で、お天気で、お坊さんは、お坊さんは、お坊さんの、お坊さんは、お目にかかりました。お坊さんは、お坊さんは
今日は、お坊さんは、お坊さんの妹《お》さん、お坊さん《おく》さんが、お坊さん《お》さんが、お坊さん《かあ》さんが、お坊さん《かあ》さんが、お坊さん《かあ》
今日は、私は、この頃の頃、私は、その頃、私は、私は、そのお坊さんの、お坊さんの妹と、お坊さんの妹と、お坊さんの妹と、お坊さんの母と、お
初期文 "私は正直だ。"
制御コード 結果
私は正直だ。そして、この話の気持を、その気持を、その気持を、その気持を、その気持を、その気持を、その気持を、その気持を、その気持を、自分の気持を、自分の
私は正直だ。私は、そのお互いに、お互いに、お互いに、お互いに、お互いに、お互いに、お互いに、お互いに、お互いに、お互いにお互いに、お互いに、お互いにお互いに、お互いに
私は正直だ。そして、そのことを、その気持に、その気持を、その気持を、その気持を、自分の知っている。そして、自分の知っている。そして、自分の知っていることは、その気持を、その気持を、自分の気持

同じ表現の繰り返しが多いが一応文章のようなものは生成できている。
制御コードによる違いはほとんどでていない。
単純に学習不足なだけか、そもそも文章の長さを制御コードの分割としていることが
不適切な可能性もある。

まとめと今後

今回は青空文庫データに対してCTRLの学習について検証した。
最も単純な状況における学習を検証したが、 Colaboだけでは学習資源が全く足りない。
単純にGCPで学習資源を増やして学習しても良いが、
それまでにもう少し学習の高速化について検討すべきである。

参考文献