制御可能な文章生成RAG - RAG学習スクリプト
はじめに
前回 まで単純なデータセットに対してRAG1の学習を行ってきたが、
RAGのモデルを学習するまでの一連の処理をまとめてgithubで公開した。
https://github.com/NeverendingNotification/rag-japanese.git
RAG学習一連の処理
RAGの学習を行うためには以下の処理を行う必要がある。
BERT学習済みモデルの準備
今回は以下のモデルを利用させてもらった。
https://github.com/cl-tohoku/bert-japanese
前回まではtokenizerの設定が正しくできておらず、日本語のtokenizerが
機能していなかったので出力結果から濁点が消えたりしていた。
今回はDPRもBART4も全て、
transformers.tokenization_bert_japanese.BertJapaneseTokenizer
を利用することで、適切な日本語のtokenizeを行う。
# rag-japanese python make_small_bert.py --pretrained-model cl-tohoku/bert-base-japanese-whole-word-masking --out-dir models/small_bert --num-layers 3
データの前処理
知識文章と質問・回答文章を含んだcsvファイルから、DPR学習を行うための
jsonファイルを作成する。 知識文章のcsvは通し番号と文章の列があればよいが、
質問・回答csvは質問文・回答文の列に加えて、その質問と関連する知識文章の
情報が必要である。 具体的には、各質問・回答ペアに対して、 Positive,
Negative, Hard-Negativeに対応する知識文章の番号情報が必要である。
質問・回答csv例
質問 | 回答 | Positive | Negative | Hard-Negative | |
---|---|---|---|---|---|
0 | 北海道の人口は538万人くらいですか? | 538万人くらいです。 | [979] | [836, 937, 11] | [96, 629, 288] |
1 | 北海道の人口は708万人くらいですか? | 708万人よりも少ないです。 | [979] | [170, 996, 272] | [979, 779, 288] |
2 | 北海道の人口は642万人くらいですか? | 642万人くらいです。 | [979] | [496, 318, 870] | [288, 679, 96] |
3 | 北海道の人口は530万人くらいですか? | 530万人くらいです。 | [979] | [793, 247, 575] | [929, 431, 779] |
4 | 青森県の人口は132万人くらいですか? | 132万人くらいです。 | [980] | [693, 69, 302] | [980, 145, 49] |
また、前回までは都道府県ごとに学習・評価を分けていたが、今回はランダムに分割している。
# rag-japanese python preprocess_data.py --knowledge-file data/knowledge.csv --qa-file data/qa.csv --out-file data/dpr_qa.json --valid-split --out-csv
DPRの学習
DPRのモデルを学習する。
Facebookの実装を元に、日本語tokenizerなど一部修正している。
https://github.com/facebookresearch/DPR
# rag-japanese/dpr python train_dense_encoder.py --train_file ../data/dpr_qa_train.json --dev_file ../data/dpr_qa_valid.json --encoder_model_type hf_bert --pretrained_model_cfg ../models/small_bert --batch_size 8 --output_dir ../models/dpr --num_train_epochs 6
DPRモデルの変換
学習したDPRモデルをtransformersライブラリのDPR形式に変換する。
# rag-japanese/dpr python convert_model.py -p ../models/dpr/dpr_biencoder.5.386 -o ../models/dpr_transformers
文章情報のindex化
知識文章を学習したDPRのcontext encoderによってembeddingして、
faiss5によりindex化する。
# rag-japanese python make_index.py --context-model models/dpr_transformers/c_encoder --knowledge-file data/knowledge.csv --out-dir data/dpr_knowlege_index
RAG学習
ここまで作成した、DPRとindexを利用して、RAGの学習を行う。
# rag-japanese python train_model.py --model-type rag --question-model models/dpr_transformers/q_encoder --train-csv data/dpr_qa_train.csv --valid-csv data/dpr_qa_valid.csv --indexdata-path data/dpr_knowlege_index/knowlege --index-path data/dpr_knowlege_index/knowlege_index.faiss --out-dir results/rag
RAG推論
学習済みRAGモデルから、テストデータに対して、RAGによる文章推論を行う。
# rag-japanese python test_model.py --model-type rag --pretrained-model results/rag --test-csv data/dpr_qa_valid.csv --indexdata-path data/dpr_knowlege_index/knowlege --index-path data/dpr_knowlege_index/knowlege_index.faiss --out-dir results/rag --out-file test.csv
出力csvファイル例
df = pd.read_csv(result_csv_file, index_col=0) print(df.set_index("質問").head(3).T.to_markdown())
2010年から2015年で千葉県の人口は変わっていますか? | 1940年から2005年で広島県の人口は変わっていますか? | 1950年から2010年で鹿児島県の人口は変わっていますか? | |
---|---|---|---|
回答 | 同じくらいです。 | 増えています。 | 同じくらいです。 |
返答 | 同じ くらい です 。 | 増え て い ます 。 | 同じ くらい です 。 |
関連1 | 2010年の千葉県の人口は621万人です。 | 1940年の広島県の人口は186万人です。 | 1950年の鹿児島県の人口は180万人です。 |
関連2 | 2015年の千葉県の人口は622万人です。 | 2005年の広島県の人口は287万人です。 | 2010年の鹿児島県の人口は170万人です。 |
関連3 | 2010年の埼玉県の人口は719万人です。 | 1940年の徳島県の人口は71万人です。 | 2015年の鹿児島県の人口は164万人です。 |
関連4 | 2010年の栃木県の人口は200万人です。 | 1930年の広島県の人口は169万人です。 | 1955年の鹿児島県の人口は204万人です。 |
関連5 | 2010年の富山県の人口は109万人です。 | 1935年の広島県の人口は180万人です。 | 2000年の鹿児島県の人口は178万人です。 |