nownab.log

nownabe's daily posts

Windows 11 の WSL で GPU を使って rinna InstructGPT

Posted on Jun 25, 2023

はじめに

最近、念願のつよつよ GPU がついた PC を新調して WSL で環境構築を頑張っている。今回は GPU を使った LLM の推論を試した。

ここでの GPU は NVIDIA のもので、GPU の環境構築は WSL で CUDA を使えるようにすることを意味する。また、WSL の Distribution は Ubuntu-22.04。

LLM としては rinna 社の日本語特化 InstructGPT を使った。

GPU on WSL

基本的に この手順 に従って進めれば WSL で GPU が使えるようになる。具体的には、Windows 11 へ WSL 対応 NVIDIA ドライバのインストール、WSL 内で CUDA Toolkit インストールの 2 点。

NVIDIA ドライバのインストールは NVIDIA のドライバダウンロードサイトで Windows 11 のものを選んでダウンロードしてインストールする。最新のものであれば WSL 2 の CUDA サポートが入っている。自分の環境だとこんな感じ。

nvidia-driver-download

Windows 側でドライバをインストールすると、WSL 側でマウントしている /usr/lib/wsl/lib に CUDA の共有ライブラリなどが見えるようになる。

CUDA Toolkit は ダウンロードページから WSL 用のものをインストールする。これは、Windows 側でインストールした NVIDIA ドライバが提供するファイルを上書きしないように配慮されている。自分の環境だとこんな感じ。

install-cuda-toolkit

ただし、表示されたコマンドのまま進めると PyTorch が対応していない最新の CUDA (今だと CUDA 12.2) がインストールされてしまうので、こんな感じで対応しているバージョンをインストールする。

wget https://developer.download.nvidia.com/compute/cuda/repos/wsl-ubuntu/x86_64/cuda-keyring_1.0-1_all.deb
sudo dpkg -i cuda-keyring_1.0-1_all.deb
sudo apt-get update
sudo apt-get -y install cuda-11-8

nvidia-smi コマンドも使えるようになっているはずなので確認ができる。

❯ nvidia-smi
Sun Jun 25 12:15:35 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.54.04              Driver Version: 536.23       CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:01:00.0  On |                  Off |
|  0%   45C    P8              39W / 450W |   1184MiB / 24564MiB |      5%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+

InstructGPT

今回コンテナは使わず Poetry の仮想環境で JupyterLab を起動して、Windows のブラウザからアクセスして rinna InstructGPT を使用した。

Poetry のインストールは asdf で。

asdf install poetry 1.5.1
asdf global poetry 1.5.1

適当にディレクトリを作って poetry init

mkdir ~/jupyter
cd ~/jupyter
poetry init

pyproject.toml の dependencies はこんな感じ。urlcp310 は Python のバージョン。

[tool.poetry.dependencies]
python = "^3.10"
jupyterlab = "^4.0.2"
torch = {url = "https://download.pytorch.org/whl/cu118/torch-2.0.1%2Bcu118-cp310-cp310-linux_x86_64.whl"}
transformers = "^4.30.2"
sentencepiece = "^0.1.99"

依存ライブラリをインストールして起動する。

poetry install
poetry run jupyter lab

Windows 側のブラウザから http://localhost:8888/ にアクセスする。トークンは JupyterLab を起動したコンソールにログとして表示されている。

まずはサンプルをそのまま動かしてみた。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-ppo")
model = model.to("cuda")

prompt = [
    {
        "speaker": "ユーザー",
        "text": "コンタクトレンズを慣れるにはどうすればよいですか?"
    },
    {
        "speaker": "システム",
        "text": "これについて具体的に説明していただけますか?何が難しいのでしょうか?"
    },
    {
        "speaker": "ユーザー",
        "text": "目が痛いのです。"
    },
    {
        "speaker": "システム",
        "text": "分かりました、コンタクトレンズをつけると目がかゆくなるということですね。思った以上にレンズを外す必要があるでしょうか?"
    },
    {
        "speaker": "ユーザー",
        "text": "いえ、レンズは外しませんが、目が赤くなるんです。"
    }
]
prompt = [
    f"{uttr['speaker']}: {uttr['text']}"
    for uttr in prompt
]
prompt = "<NL>".join(prompt)
prompt = (
    prompt
    + "<NL>"
    + "システム: "
)

token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        do_sample=True,
        max_new_tokens=128,
        temperature=0.7,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
output = output.replace("<NL>", "\n")
print(output)

ちゃんした日本語が出力される。

わかりました。コンタクトレンズを長時間つけっぱなしにしている場合は、目の周りに水ぶくれができやすくなります。また、コンタクトレンズを長時間つけたままにしておくことで、目にゴミやほこりが入りやすくなり、それらが刺激となってかゆみを引き起こすことがあります。そのため、コンタクトレンズを付ける前には必ず手を洗って、清潔な状態にすることが大切です。</s>

こんな関数を作っていろいろ遊んでみた。

def chat(prompt):
    prompt = f"ユーザー: {prompt}<NL>システム: "
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            do_sample=True,
            max_new_tokens=128,
            temperature=0.7,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

    output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
    return output.replace("<NL>", "\n")

推論してるときのリソースはこんな感じ。

gpu-performance

おわりに

ローカルで LLM 動くの楽しい!