BERTで推論を高速化できるTorchScriptを試してみる

BERTの推論を高速化したいです。

C++でも動作可能かつ、グラフを静的にできるTorchScriptを使用すればできるぞ

TorchScriptはPyTorchのコードをモデルを最適化して、C++のコードでも動作可能にする機能です。

https://pytorch.org/docs/stable/jit.html

動作環境

Google Colabで動作確認しました。

日本語のBERTモデルをTorchScriptで変換

MASK部分を予測する日本語のBERTモデルを使用して動作確認をします。

モデルを取得する際にtorchscript=Trueを設定します。

Top 5の候補を予測してみます。

from transformers import AutoModelWithLMHead, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
model = AutoModelWithLMHead.from_pretrained("cl-tohoku/bert-base-japanese", torchscript=True)

sequence = f"ジムは{tokenizer.mask_token}ですか? "

input = tokenizer.encode(sequence, return_tensors="pt")
mask_token_index = torch.where(input == tokenizer.mask_token_id)[1]

token_logits = model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]

top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for i in range(len(top_5_tokens)):
    print(tokenizer.decode(top_5_tokens[i]))

下記のような候補が確認できました。5番目の候補が正解に近いと思います。

ど こ
何
な ん
ど う
誰

速度を計測してみます。

%%time
for i in range(100):
    model(input)

下記のような結果になりました。

CPU times: user 12.1 s, sys: 29.8 ms, total: 12.1 s
Wall time: 12.1 s

続いてTorchScriptを使用してモデルを変換して保存します。

traced_model = torch.jit.trace(model, input)
torch.jit.save(traced_model, "traced_bert.pt")

同一に動作するか確認してみます。

token_logits = traced_model(input)[0]
mask_token_logits = token_logits[0, mask_token_index, :]

top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
for i in range(len(top_5_tokens)):
    print(tokenizer.decode(top_5_tokens[i]))

下記のような結果になっているので、同一結果を取得できています。

ど こ
何
な ん
ど う
誰

同様に速度計測をしてみます。

%%time
for i in range(100):
    traced_model(input)

下記のような結果になり、8%ほど速度向上しています。

CPU times: user 11.2 s, sys: 16.1 ms, total: 11.2 s
Wall time: 11.2 s

Close Bitnami banner
Bitnami