制御可能な文章生成RAG - RAG学習スクリプト

はじめに

前回 まで単純なデータセットに対してRAG1の学習を行ってきたが、
RAGのモデルを学習するまでの一連の処理をまとめてgithubで公開した。

https://github.com/NeverendingNotification/rag-japanese.git

RAG学習一連の処理

RAGの学習を行うためには以下の処理を行う必要がある。

  1. BERT学習済みモデルの準備
  2. データの前処理
  3. DPR2の学習
  4. DPRモデルのtransformers3モデルへの変換
  5. 知識文章のindex化

BERT学習済みモデルの準備

今回は以下のモデルを利用させてもらった。

https://github.com/cl-tohoku/bert-japanese

前回まではtokenizerの設定が正しくできておらず、日本語のtokenizerが
機能していなかったので出力結果から濁点が消えたりしていた。

今回はDPRもBART4も全て、
transformers.tokenization_bert_japanese.BertJapaneseTokenizer
を利用することで、適切な日本語のtokenizeを行う。

# rag-japanese
python make_small_bert.py --pretrained-model cl-tohoku/bert-base-japanese-whole-word-masking --out-dir models/small_bert --num-layers 3

データの前処理

知識文章と質問・回答文章を含んだcsvファイルから、DPR学習を行うための
jsonファイルを作成する。 知識文章のcsvは通し番号と文章の列があればよいが、
質問・回答csvは質問文・回答文の列に加えて、その質問と関連する知識文章の
情報が必要である。 具体的には、各質問・回答ペアに対して、 Positive,
Negative, Hard-Negativeに対応する知識文章の番号情報が必要である。

質問・回答csv
質問 回答 Positive Negative Hard-Negative
0 北海道の人口は538万人くらいですか? 538万人くらいです。 [979] [836, 937, 11] [96, 629, 288]
1 北海道の人口は708万人くらいですか? 708万人よりも少ないです。 [979] [170, 996, 272] [979, 779, 288]
2 北海道の人口は642万人くらいですか? 642万人くらいです。 [979] [496, 318, 870] [288, 679, 96]
3 北海道の人口は530万人くらいですか? 530万人くらいです。 [979] [793, 247, 575] [929, 431, 779]
4 青森県の人口は132万人くらいですか? 132万人くらいです。 [980] [693, 69, 302] [980, 145, 49]

​ また、前回までは都道府県ごとに学習・評価を分けていたが、今回はランダムに分割している。

# rag-japanese
python preprocess_data.py --knowledge-file data/knowledge.csv --qa-file data/qa.csv --out-file data/dpr_qa.json --valid-split --out-csv

DPRの学習

DPRのモデルを学習する。
Facebookの実装を元に、日本語tokenizerなど一部修正している。

https://github.com/facebookresearch/DPR

# rag-japanese/dpr
python train_dense_encoder.py --train_file ../data/dpr_qa_train.json --dev_file ../data/dpr_qa_valid.json --encoder_model_type hf_bert --pretrained_model_cfg ../models/small_bert --batch_size 8 --output_dir ../models/dpr --num_train_epochs 6

DPRモデルの変換

学習したDPRモデルをtransformersライブラリのDPR形式に変換する。

# rag-japanese/dpr
python convert_model.py -p ../models/dpr/dpr_biencoder.5.386 -o ../models/dpr_transformers

文章情報のindex化

知識文章を学習したDPRのcontext encoderによってembeddingして、
faiss5によりindex化する。

# rag-japanese
python make_index.py  --context-model models/dpr_transformers/c_encoder --knowledge-file data/knowledge.csv --out-dir data/dpr_knowlege_index

RAG学習

ここまで作成した、DPRとindexを利用して、RAGの学習を行う。

# rag-japanese
python train_model.py  --model-type rag --question-model models/dpr_transformers/q_encoder --train-csv data/dpr_qa_train.csv --valid-csv data/dpr_qa_valid.csv --indexdata-path data/dpr_knowlege_index/knowlege --index-path data/dpr_knowlege_index/knowlege_index.faiss --out-dir  results/rag

RAG推論

学習済みRAGモデルから、テストデータに対して、RAGによる文章推論を行う。

# rag-japanese
python test_model.py  --model-type rag --pretrained-model results/rag --test-csv data/dpr_qa_valid.csv --indexdata-path data/dpr_knowlege_index/knowlege --index-path data/dpr_knowlege_index/knowlege_index.faiss --out-dir results/rag --out-file test.csv

出力csvファイル例

df = pd.read_csv(result_csv_file, index_col=0)
print(df.set_index("質問").head(3).T.to_markdown())
2010年から2015年で千葉県の人口は変わっていますか? 1940年から2005年で広島県の人口は変わっていますか? 1950年から2010年で鹿児島県の人口は変わっていますか?
回答 同じくらいです。 増えています。 同じくらいです。
返答 同じ くらい です 。 増え て い ます 。 同じ くらい です 。
関連1 2010年の千葉県の人口は621万人です。 1940年の広島県の人口は186万人です。 1950年の鹿児島県の人口は180万人です。
関連2 2015年の千葉県の人口は622万人です。 2005年の広島県の人口は287万人です。 2010年の鹿児島県の人口は170万人です。
関連3 2010年の埼玉県の人口は719万人です。 1940年の徳島県の人口は71万人です。 2015年の鹿児島県の人口は164万人です。
関連4 2010年の栃木県の人口は200万人です。 1930年の広島県の人口は169万人です。 1955年の鹿児島県の人口は204万人です。
関連5 2010年の富山県の人口は109万人です。 1935年の広島県の人口は180万人です。 2000年の鹿児島県の人口は178万人です。

参考文献