BERTの推論を高速化したいです。
C++でも動作可能かつ、グラフを静的にできるTorchScriptを使用すればできるぞ
TorchScriptはPyTorchのコードをモデルを最適化して、C++のコードでも動作可能にする機能です。
https://pytorch.org/docs/stable/jit.html
動作環境
Google Colabで動作確認しました。
日本語のBERTモデルをTorchScriptで変換
MASK部分を予測する日本語のBERTモデルを使用して動作確認をします。
モデルを取得する際にtorchscript=True
を設定します。
Top 5の候補を予測してみます。
from transformers import AutoModelWithLMHead, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
model = AutoModelWithLMHead.from_pretrained("cl-tohoku/bert-base-japanese", torchscript=True)
sequence = f"ジムは{tokenizer.mask_token}ですか? "
input = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]
token_logits = model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for i in range(len(top_5_tokens)):
print(tokenizer.decode(top_5_tokens[i]))
下記のような候補が確認できました。5番目の候補が正解に近いと思います。
ど こ
何
な ん
ど う
誰
速度を計測してみます。
%%time
for i in range(100):
model(input)
下記のような結果になりました。
CPU times: user 12.1 s, sys: 29.8 ms, total: 12.1 s
Wall time: 12.1 s
続いてTorchScriptを使用してモデルを変換して保存します。
traced_model = torch.jit.trace(model, input)
torch.jit.save(traced_model, "traced_bert.pt")
同一に動作するか確認してみます。
token_logits = traced_model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for i in range(len(top_5_tokens)):
print(tokenizer.decode(top_5_tokens[i]))
下記のような結果になっているので、同一結果を取得できています。
ど こ
何
な ん
ど う
誰
同様に速度計測をしてみます。
%%time
for i in range(100):
traced_model(input)
下記のような結果になり、8%ほど速度向上しています。
CPU times: user 11.2 s, sys: 16.1 ms, total: 11.2 s
Wall time: 11.2 s