NLP を学ぶ - 3
はじめに
今日はNLP関連の一般公開されているデータセットとその解析手法について調査する。
データセットの特性を理解することは機械学習モデルを構築するうえで重要である。
手頃に操作可能なサイズのデータセットに実際に触れることでそれらの手法を理解したい。
NLPのデータセットについて
torchtextのdataset APIのページに利用可能なデータセットの一覧がある。
https://pytorch.org/text/datasets.html
2020/4/3 (torchtext 0.5.1)の段階では以下のようになっている。
WikiText-2
WikiText-2データセットはWikipediaから抽出したデータセットである。
言語モデリングの検証用のデータセットである。
以下のtorchtextチュートリアルが分かりやすかった。
https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
データの読み込み
import torch import spacy # Fieldの作成用 from torchtext import data, datasets # 1. Fieldというtokenizerとvocabularyなどをまとめたクラスを作成 TEXT = data.Field(tokenize=data.get_tokenizer('spacy'), init_token='<SOS>', eos_token='<EOS>',lower=True) # 2. torchtextがデータのダウンロード、前処理、学習、評価、テストへの分割などをやってくれる。 train,val,test = datasets.WikiText2.splits(text_field=TEXT)
通常のpytorchのデータセットはtorch.utils.data.Datasetを継承して、iter関数とlen関数を定義するだけなのだが、
torchtextの場合は以下の流れが必要なようだ。
1. Fieldの定義
2. Exampleの設定
WikiText2ではsplitsでExampleの設定は行ってくれている。
Fieldの設定は自分で行う必要がある。
Fieldの定義
Fieldはデータの前処理を行うobjectである。 vocabulary情報も保持可能。
ここでは
- spacyによるtokenize
- 文頭に開始文字
- 文末に終了文字
- 文字を小文字に変換
の処理を行っている。 pad処理(文章の長さを揃えるためにダミー文字で埋める処理)もFieldに設定するようだが今回は利用していない。
Example
Exampleとは1つのデータ単位である。
教師あり学習の場合は入力データとラベルのペアがExampleに対応する。
torchtextのdatasetの場合、このExampleのリストをメンバー変数(self.examples)として保持し dataset[i] でi番目のexampleにアクセルできる。 一つのExampleは辞書ライクなオブジェクトであり、今回のデータの場合は.textでテキスト情報にアクセス可能。
example = train[0] # train Datasetの 0番目のexampleにアクセス text = example.text # 0番目のexampleのテキストにアクセス
今回のデータセットは言語モデリング用のデータセットであるためラベルは存在しない。
また、WikiText2のDatasetでは1つのexampleにtext全体が入っている形式となっている。
データセットの解析
統計量
train, valid, testデータに対して簡単な統計量を求める。
import pandas as pd datasets = [train, val, test] rows = [] columns = ["統計量", "学習", "評価", "テスト"] # 各データセットの統計量 rows.append(["単語数"] + list(map(lambda x:len(x[0].text), datasets))) rows.append(["語彙数"] + list(map(lambda x:len(set(x[0].text)), datasets))) rows.append(["文章数"] + list(map(lambda x:x[0].text.count("<eos>"), datasets))) rows.append(["1語彙あたりの平均出現数"] + list(map(lambda x,y:int(x/y), rows[0][1:], rows[1][1:]))) rows.append(["1文あたりの平均単語数"] + list(map(lambda x,y:int(x/y), rows[0][1:], rows[2][1:]))) stat_df = pd.DataFrame(rows, columns=columns).set_index("統計量") print(stat_df.to_markdown())
統計量 | 学習 | 評価 | テスト |
---|---|---|---|
単語数 | 2.23665e+06 | 245042 | 280576 |
語彙数 | 28868 | 12044 | 12501 |
文章数 | 36718 | 3760 | 4358 |
1語彙あたりの平均出現数 | 77 | 20 | 22 |
1文あたりの平均単語数 | 60 | 65 | 64 |
学習、テスト、評価データで10:1:1程度のデータサイズであるが語彙数は2:1:1となっている。
語彙の豊富さについての指標も計算したほうが良い? 1
評価指標について
言語モデリングの評価指標にperplexityがある。
言語モデリングは文脈から単語を予測するモデルを構築することである。
perplexityは大まかにはある単語の予測確率の逆数であり小さいほど性能が良いこととなる。
wikitext-2データに対するperplexityは以下から見ることができる。
https://paperswithcode.com/sota/language-modelling-on-wikitext-2
2020/4/3現在ではGPT-2が20を下回るPerplexityを達成し、他を圧倒している。
(他のモデルは外部データを使っていないことも差が大きい原因?)
モデルの学習
# パラメータ設定 batch_size = 32 bptt_len = 30 device = "cuda" if torch.cuda.is_available() else "cpu" n_tokens = len(TEXT.vocab) n_inps = 128 n_hiddens = 64 n_layers = 1 dropout = 0 # BPTT学習用のデータloader作成 trainset, valset, testset = data.BPTTIterator.splits( (train, val, test), batch_size=batch_size, bptt_len=bptt_len, device=device)
# モデル構築 import torch.nn as nn import torch.nn.functional as F class RNNModel(nn.Module): def __init__(self, ntoken, ninp, nhid, nlayers, bsz, device, dropout=0.5, tie_weights=True): super(RNNModel, self).__init__() self.nhid, self.nlayers, self.bsz = nhid, nlayers, bsz self.drop = nn.Dropout(dropout) self.encoder = nn.Embedding(ntoken, ninp) self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout) self.decoder = nn.Linear(nhid, ntoken) self.device = device self.hidden = self.init_hidden(bsz) def forward(self, input): emb = self.drop(self.encoder(input)) output, self.hidden = self.rnn(emb, self.hidden) output = self.drop(output) decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2))) return decoded.view(output.size(0), output.size(1), decoded.size(1)) def init_hidden(self, bsz): return [torch.zeros([1, bsz, self.nhid]).to(self.device), torch.zeros([1, bsz, self.nhid]).to(self.device)] def reset_history(self, bsz): self.hidden = self.init_hidden(bsz) model = RNNModel(n_tokens, n_inps, n_hiddens, n_layers, batch_size, device, dropout=dropout) print(model)
# 学習 import torch.optim as optim from tqdm import tqdm_notebook.tqdm as tqdm criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.7, 0.99)) losses = [] # 1エポック文だけ学習 for batch in tqdm(trainset): model.reset_history(batch.batch_size) optimizer.zero_grad() text, targets = batch.text, batch.target prediction = model(text) loss = criterion(prediction.view(-1, n_tokens), targets.view(-1)) loss.backward() optimizer.step() losses.append(loss.item())
# 可視化 import pandas as pd loss_ser = pd.Series(losses, index=range(len(losses))) loss_ser.index.name = "step" # windows size 10 の ewmでスムージング loss_ser.ewm(span=10).mean().plot(title="batch loss in 1 epoch", logy=True)
# 予測 import numpy as np model.eval() for batch in valset: text, targets = batch.text, batch.target with torch.no_grad(): prediction = model(text) break pred_words = np.argmax(prediction.cpu().numpy(), axis=2) rows = [] for i in [5, 10, 20]: rows.append(["入力"] + [vocab.itos[t] for t in text[:, i]] + [""] ) rows.append(["正解"] + [""] + [vocab.itos[t] for t in targets[:, i]]) rows.append(["予測"] + [""] + [vocab.itos[t] for t in pred_words[:, i]]) print(pd.DataFrame(rows).set_index(0).to_markdown())
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
入力 | we | ' | ve | never | had | a | race | driver | like | tim | in | stock | car | racing | . | he | was | almost | a | james | dean | @-@ | like | character | . | " | when | richmond | was | cast | |
正解 | ' | ve | never | had | a | race | driver | like | tim | in | stock | car | racing | . | he | was | almost | a | james | dean | @-@ | like | character | . | " | when | richmond | was | cast | for | |
予測 | to | t | " | a | been | < | of | . | the | , | the | , | . | . | |
was | the | the | < | @-@ | of | < | < | , | |
|
the | the | also | to | |
入力 | also | serves | as | evacuation | route | from | cape | may | county | to | inland | areas | in | the | event | of | a | hurricane | . | |
|
= | = | history | = | = | |
||||
正解 | serves | as | evacuation | route | from | cape | may | county | to | inland | areas | in | the | event | of | a | hurricane | . | |
|
= | = | history | = | = | |
|
||||
予測 | in | the | , | , | the | , | be | , | the | , | . | the | < | . | the | < | . | |
|
= | = | = | = | = | = | |
|||||
入力 | of | < | unk | > | 2000 | . | |
< | unk | > | gained | fame | as | attorney | of | < | unk | > | < | unk | > | , | a | kurdish | national | who | was | first | charged | ||
正解 | < | unk | > | 2000 | . | |
< | unk | > | gained | fame | as | attorney | of | < | unk | > | < | unk | > | , | a | kurdish | national | who | was | first | charged | in | ||
予測 | the | unk | > | , | . | |
|
unk | > | , | the | , | the | , | the | unk | > | , | unk | > | , | and | < | of | < | was | the | in | . |
1エポック程度の学習では適切な言語モデルは学習できていないがunknown文字や一部記号などの頻出文字は正解している部分もある。
AG-NEWS
AG-NEWSデータセットはWikipediaから抽出したデータセットである。
ニュース記事を世界、経済、スポーツ、科学技術の4クラスに分類してあるため分類の検証に利用できる。
torchtext公式ページにチュートリアルがあった。
https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html
データ解析
データ読み込み
import os import torch import torchtext from torchtext.datasets import text_classification data_dir = "./.data" NGRAMS = 1 os.makedirs(data_dir, exist_ok=True) train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS']( root=data_dir, ngrams=NGRAMS, vocab=None)
text_classification データセットはFieldを利用していない。(なぜ?)
データサンプル (クラス 0:世界、 1:経済、 2:スポーツ、3: 科学技術)
label | text |
---|---|
3 | china launches mapping and surveying satellite china sunday launched a satellite that will carry out land surveys and mapping for several days before returning to earth , the xinhua news agency said . |
0 | indonesia #39 s bashir retrial gets underway indonesia #39 s prominent muslim preacher abu bakr bashir has been put on trial again , charged once more over the marriot hotel attack in 2003 and the bali bombing in 2002 . |
2 | executive shake-up unveiled at cbs , paramount tv in a major shake-up of its west coast programing operations , media giant viacom inc . on tuesday promoted several cbs executives to new roles at the broadcaster and a newly merged television studio . |
1 | jets edwards says pennington is still hurting jets coach herman edwards said yesterday that chad pennington still has some discomfort in his right arm , lingering effects of a strained rotator cuff . |
2 | unions rally resistance at jaguar unions hold emergency meetings with workers at jaguar ' s doomed browns lane plant in coventry to fight closure plans . |
2 | dot orders fedex to repay \$29 million memphis , tenn . memphis-based fedex is challenging a transportation department order to repay 29 ( m ) million dollars . it #39 s part of the federal money the package carrier received after the 2001 terrorist attacks shut down flights . |
統計量 | 学習 | テスト |
---|---|---|
文章数 | 120000 | 7600 |
単語数 | 5193609 | 327070 |
語彙数 | 95810 | 22447 |
一文あたりが平均50単語程度の長さである。
ラベル | 学習 | テスト |
---|---|---|
0 | 30000 | 1900 |
1 | 30000 | 1900 |
2 | 30000 | 1900 |
3 | 30000 | 1900 |
ラベルは均等に分布している。
学習モデル
import torch.nn as nn import torch.nn.functional as F class TextSentiment(nn.Module): def __init__(self, vocab_size, embed_dim, num_class): super().__init__() self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True) self.fc = nn.Linear(embed_dim, num_class) self.init_weights() def init_weights(self): initrange = 0.5 self.embedding.weight.data.uniform_(-initrange, initrange) self.fc.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() def forward(self, text, offsets): embedded = self.embedding(text, offsets) return self.fc(embedded)
基本的には文章をembeddingして全結合層に通して分類しているだけだが、
EmbeddingBagという処理が使われている。
これは文章をその文章を構成する単語のEmbeddingの和として効率よくEmbeddingするものであり、
文章の系列情報が必要ないテキスト分類に利用されるようだ。
以下の記事の解説が分かりやすかった。
https://www.dskomei.com/entry/2019/12/26/225959
学習パラメータは基本的にチュートリアルのものと同じだが、
エポック数を5→10
NGRAMを 2 → 1
というように変更している。
結果
accuracyの推移を見ると、学習用データに対するaccuracyは上昇傾向で、
評価用データに対するaccuracyは1エポック学習した後はほぼ横ばいである。
チュートリアルページのNGRAM=2のtest accuracyが0.905で今回の
NGRAM=1のtest accuracyが0.901であるのでやや過学習している可能性はある。
テストデータに対する混同行列
正解\予測 | 世界 | 経済 | スポーツ | 科学技術 |
---|---|---|---|---|
世界 | 1694 | 66 | 95 | 45 |
経済 | 17 | 1855 | 23 | 5 |
スポーツ | 70 | 18 | 1715 | 97 |
科学技術 | 58 | 27 | 228 | 1587 |
混同行列を見る限り、科学技術の記事をスポーツ記事と予測している誤判定が多い。
誤判定データ
正解 : 科学技術→予測 : スポーツ | |
---|---|
5312 | microsoft makes another antitrust deal software giant settles with novell and the ccia , ending years of legal wrangling . |
5546 | four in court over sql theft four former microsoft employees have been charged with stealing |
173 | small computers can have multiple personalities , too boston the jury is still out on whether a computer can ever truly be intelligent , but there is no question that it can have multiple personalities . it #39 s just a matter of software . we usually think of the processor chip as the brains of a computer . the . . . |
1884 | nortel lowers expectations nortel said it expects revenue for the third quarter to fall short of expectations . |
3060 | you have mail , always , with a blackberry washington lawyer william wilhelm knows from experience that not everybody loves his blackberry as much as he does . the girlfriend was fed up with a relationship |
3800 | google unveils desktop search , takes on microsoft google inc . ( goog . o quote , profile , research ) on thursday rolled out a preliminary version of its new desktop search tool , making the first move against |
9 | card fraud unit nets 36 , 000 cards in its first two years , the uk ' s dedicated card fraud unit , has recovered 36 , 000 stolen cards and 171 arrests - and estimates it saved 65m . |
4622 | korean and japanese phone makers win -survey amsterdam ( reuters ) - south korean mobile phone makers continued a rapid move up the global market rankings during the third quarter , while growth in the wider mobile phone market slowed , a survey found on wednesday . |
4058 | regulators approve artificial heart the food and drug administration approved the use of an artificial heart made by syncardia systems as a temporary device for people awaiting transplants . |
6278 | update peoplesoft board won ' t negotiate oracle takeover peoplesoft executives said over the weekend that they won ' t discuss a sale to oracle at a price of \$24 per share but would consider an offer at a higher price . |
感覚的には普通に科学の記事に見えるため、なぜ誤判定しているか理解できない。
何が原因でこのような結果となったかを解析する技術についても今後調査していきたい。
まとめと今後
今回は WikiText-2データセット(言語モデル)とAG-NEWSデータセット(文章分類)に
実際に触れてみた。 次回は別のタスクのデータセットに触れてみる。
また、今後は結果の解析や解釈の手法についても調査をしていきたい。