hajimemath blog

LLama2をGoogle Colaboratoryでfine tuningしてみよう!

公開日時:2023.12.01

背景

GIG inc Advent Calendar 2023 1日目です!
https://qiita.com/advent-calendar/2023/gig_inc

最近LLMが話題になっているのは皆さんもご存知だとは思いますが、私も例に漏れず触ってみたいとなったので、LLMをfine tuningしてみたいなと思います!

qloraを使用してfine tuningを行おうと思いましたが、qloraはApple Silicon CPUだとGPUの関係でfine tuningは難しいらしいので、
Google Colaboratoryを使用してfine tuningしていきたいと思います。

https://colab.research.google.com/?hl=ja

導入方法

1.GPUの選択
メニュー「編集→ノートブックの設定」で、「ハードウェアアクセラレータ」で「T4 GPU」を選択。

2.Googleドライブのマウント

from google.colab import drive
drive.mount("/content/drive")


3.作業フォルダへの移動

import os
os.makedirs("/content/drive/My Drive/work", exist_ok=True)
%cd "/content/drive/My Drive/work"


4.qloraパッケージのインストール

!git clone https://github.com/artidoro/qlora
%cd qlora
!pip install -U -r requirements.txt


5.HuggingFaceのログイン

リンクからHuggingFace Hubのトークンを取得し、「Token:」に入力してください。

!huggingface-cli login

6.HuggingFaceへデータセットをアップロード

必要であればHuggingFaceへjsonをアップロードする。
データセットをHuggingFaceにアップロードし、そのデータセットを参照します。

7.「qlora.py」の編集

qlora.pyを編集します。
qlora.pyは google drive/work/qlora/qlora.py にあります。
HuggingFaceにあるdatasetの識別子に書き換えてあげます。


if dataset_name == 'alpaca':
      return load_dataset({データセットの識別子}) # 書き換え



例として、以下のデータセットの場合は bbz662bbz/databricks-dolly-15k-ja-gozarinnemon
https://huggingface.co/datasets/bbz662bbz/databricks-dolly-15k-ja-gozarinnemon


           dataset = dataset.map(
                extract_alpaca_dataset, remove_columns=["instruction"]
            )
            print("dataset:", dataset["train"][0])  # 追加



model.config.use_cache = False
model.config.pretraining_tp = 1  # 追加



データセットのパラメータ数によって指定する値も変わってきます。
・7B: config.pretraining_tp=1
・13B: config.pretraining_tp=2

8.学習の実行

!python qlora.py \
    --model_name meta-llama/Llama-2-7b-hf \ # HuggingFace上のモデルを参照する
    --output_dir "./output/test_llama2" \ # google drive上の出力先を指定
    --dataset "alpaca" \ # "alpaca"を指定することで、上記のデータセットを読み込む
    --max_steps 1000 \
    --use_auth \
    --logging_steps 10 \
    --save_strategy steps \
    --data_seed 42 \
    --save_steps 50 \
    --save_total_limit 40 \
    --max_new_tokens 32 \
    --dataloader_num_workers 1 \
    --group_by_length \
    --logging_strategy steps \
    --remove_unused_columns False \
    --do_train \
    --lora_r 64 \
    --lora_alpha 16 \
    --lora_modules all \
    --double_quant \
    --quant_type nf4 \
    --fp16 \ # T4 GPUだとbfは使用できないので、fp16で対応する
    --bits 4 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type constant \
    --gradient_checkpointing \
    --source_max_len 16 \
    --target_max_len 512 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --eval_steps 187 \
    --learning_rate 0.0002 \
    --adam_beta2 0.999 \
    --max_grad_norm 0.3 \
    --lora_dropout 0.1 \
    --weight_decay 0.0 \
    --seed 0 \
    --load_in_4bit \
    --use_peft \
    --batch_size 4 \
    --gradient_accumulation_steps 2


9.トークナイザーとモデルの読み込み
推論を行うためにトークナイザーとモデルの読み込みを行います。

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# トークナイザーとモデルの読み込み
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf"
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    ),
    device_map={"":0}
)



10.LoRAの読み込み

model = PeftModel.from_pretrained(
    model,
    "./output/test_llama2/checkpoint-1000/adapter_model/", # 学習の実行時に指定したパスを書く
    device_map={"":0}
)
model.eval()



11.推論の実行

# プロンプトの準備
prompt = "### Instruction: 富士山とは?\n\n### Response: "

# 推論の実行
inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


まとめ

Google Colaboratoryをまともに触ったのは初めてなのですが、LLama2の7b程度であれば、fine tuningできるのがすごすぎるというのが素直な感想です。

大いに参考にさせていただいたサイト様

https://note.com/npaka/n/na7c631175111