Hugging FaceのBERTモデルの推論をTorch TensorRTで高速化

Hugging Faceのモデルで高速に推論したいです。

Torch TensorRTが対応しているぞ

事前準備

Torch TensorRTに関しては下記記事をご覧ください

Torch TensorRTの環境はdockerを使用すると簡単に用意できます。

https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch

docker run --gpus all -it --rm -p 8887:8887 --name tensorrt nvcr.io/nvidia/pytorch:22.07-py3

下記コードを参考に実行します。

https://github.com/pytorch/TensorRT/blob/master/notebooks/Hugging-Face-BERT.ipynb

Docker内にすでにコードがあるので、そのコードを使用します。

Jupyter-labを起動してコードにアクセスします。

jupyter-lab --port=8887

/workspace/examples/torch_tensorrt/notebooksディレクトリにHugging-Face-BERT.ipynbのNotebookがあるのでこちらで検証します。

ライブラリの導入

Hugging Faceのライブラリをインストールします。

!pip install transformers

ライブラリをインポートします。

from transformers import BertTokenizer, BertForMaskedLM
import torch
import timeit
import numpy as np
import torch_tensorrt
import torch.backends.cudnn as cudnn

BERTモデル

BERTのモデルはTransformerのモデルをベースに作成されています。

Hugging Faceを使用することで簡単にBERTのモデルを利用できます。

下記コードで事前学習済みモデルを取得します。

enc = BertTokenizer.from_pretrained('bert-base-uncased')

BERTのモデルに与えるダミーのインプットを作成します。

batch_size = 4

batched_indexed_tokens = [[101, 64]*64]*batch_size
batched_segment_ids = [[0, 1]*64]*batch_size
batched_attention_masks = [[1, 1]*64]*batch_size

tokens_tensor = torch.tensor(batched_indexed_tokens)
segments_tensor = torch.tensor(batched_segment_ids)
attention_masks_tensor = torch.tensor(batched_attention_masks)

高速に推論できるようにTorch Script化を行います。

mlm_model_ts = BertForMaskedLM.from_pretrained('bert-base-uncased', torchscript=True)
traced_mlm_model = torch.jit.trace(mlm_model_ts, [tokens_tensor, segments_tensor, attention_masks_tensor])

推論用のデータを準備します。[MASK]部分を推論します。

masked_sentences = ['Paris is the [MASK] of France.', 
                    'The primary [MASK] of the United States is English.', 
                    'A baseball game consists of at least nine [MASK].', 
                    'Topology is a branch of [MASK] concerned with the properties of geometric objects that remain unchanged under continuous transformations.']
pos_masks = [4, 3, 9, 6]

下記コードで推論用のデータを推論します。

自然言語処理ではデータのサイズが異なるケースが多いので最大長を指定して、それに合わせてPadding処理で穴埋めしています。

最大長は128にしています。

encoded_inputs = enc(masked_sentences, return_tensors='pt', padding='max_length', max_length=128)
outputs = mlm_model_ts(**encoded_inputs)
most_likely_token_ids = [torch.argmax(outputs[0][i, pos, :]) for i, pos in enumerate(pos_masks)]
unmasked_tokens = enc.decode(most_likely_token_ids).split(' ')
unmasked_sentences = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens)]
for sentence in unmasked_sentences:
    print(sentence)

下記のような結果が取得できます。

[MASK]部分が正しく推論できているように見えます。

Paris is the capital of France.
The primary language of the United States is English.
A baseball game consists of at least nine innings.
Topology is a branch of mathematics concerned with the properties of geometric objects that remain unchanged under continuous transformations.

Dummyの入力でTraceしたモデルで推論します。データの与え方は'input_ids''token_type_ids', 'attention_mask'の3つのキーで別々で与える形にしています。

encoded_inputs = enc(masked_sentences, return_tensors='pt', padding='max_length', max_length=128)
outputs = traced_mlm_model(encoded_inputs['input_ids'], encoded_inputs['token_type_ids'], encoded_inputs['attention_mask'])
most_likely_token_ids = [torch.argmax(outputs[0][i, pos, :]) for i, pos in enumerate(pos_masks)]
unmasked_tokens = enc.decode(most_likely_token_ids).split(' ')
unmasked_sentences = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens)]
for sentence in unmasked_sentences:
    print(sentence)

通常のモデルとは異なる結果を出していますが、こちらも正しい推論結果のように見えます。

Paris is the all of France.
The primary all of the United States is English.
A baseball game consists of at least nine each.
Topology is a branch of each concerned with the properties of geometric objects that remain unchanged under continuous transformations.

Torch TensorRT形式に変換

ログレベルを変更します。

new_level = torch_tensorrt.logging.Level.Error
torch_tensorrt.logging.set_reportable_log_level(new_level)

Torch TensorRT形式に変換します。

trt_model = torch_tensorrt.compile(traced_mlm_model, 
    inputs= [torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # input_ids
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # token_type_ids
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32)], # attention_mask
    enabled_precisions= {torch.float32}, # Run with 32-bit precision
    workspace_size=2000000000,
    truncate_long_and_double=True
)

同様に推論処理をします。

enc_inputs = enc(masked_sentences, return_tensors='pt', padding='max_length', max_length=128)
enc_inputs = {k: v.type(torch.int32).cuda() for k, v in enc_inputs.items()}
output_trt = trt_model(enc_inputs['input_ids'], enc_inputs['token_type_ids'], enc_inputs['attention_mask'])
most_likely_token_ids_trt = [torch.argmax(output_trt[i, pos, :]) for i, pos in enumerate(pos_masks)] 
unmasked_tokens_trt = enc.decode(most_likely_token_ids_trt).split(' ')
unmasked_sentences_trt = [masked_sentences[i].replace('[MASK]', token) for i, token in enumerate(unmasked_tokens_trt)]
for sentence in unmasked_sentences_trt:
    print(sentence)

推論結果は下記のようになります。

Paris is the all of France.
The primary all of the United States is English.
A baseball game consists of at least nine each.
Topology is a branch of each concerned with the properties of geometric objects that remain unchanged under continuous transformations.

数値精度をFP16に変更して実行します。

trt_model_fp16 = torch_tensorrt.compile(traced_mlm_model, 
    inputs= [torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # input_ids
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32),  # token_type_ids
             torch_tensorrt.Input(shape=[batch_size, 128], dtype=torch.int32)], # attention_mask
    enabled_precisions= {torch.half}, # Run with 16-bit precision
    workspace_size=2000000000,
    truncate_long_and_double=True
)

モデルの推論性能の比較

ベンチマークを行うための関数を作成します。Warm Up処理後に50回ほど推論処理をします。

def timeGraph(model, input_tensor1, input_tensor2, input_tensor3, num_loops=50):
    print("Warm up ...")
    with torch.no_grad():
        for _ in range(20):
            features = model(input_tensor1, input_tensor2, input_tensor3)

    torch.cuda.synchronize()

    print("Start timing ...")
    timings = []
    with torch.no_grad():
        for i in range(num_loops):
            start_time = timeit.default_timer()
            features = model(input_tensor1, input_tensor2, input_tensor3)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            timings.append(end_time - start_time)
            # print("Iteration {}: {:.6f} s".format(i, end_time - start_time))

    return timings

取得したベンチマークの値のレイテンシー、スループットの中央値、平均値、標準偏差を取得します。

比較して見やすいようにコードを若干変更しています。

def printStats(graphName, timings, batch_size):
    times = np.array(timings)
    steps = len(times)
    speeds = batch_size / times
    time_mean = np.mean(times)
    time_med = np.median(times)
    time_99th = np.percentile(times, 99)
    time_std = np.std(times, ddof=0)
    speed_mean = np.mean(speeds)
    speed_med = np.median(speeds)

    msg = ("\n%s =================================\n"
            "batch size=%d, num iterations=%d\n"
            "  Median text batches/second: %.1f, mean: %.1f\n"
            "  Median latency: %.6f, mean: %.6f, 99th_p: %.6f, std_dev: %.6f\n"
            ) % (graphName,
                batch_size, steps,
                speed_med, speed_mean,
                time_med, time_mean, time_99th, time_std)
    print(msg)
    return speed_med, speed_mean, time_med, time_mean
cudnn.benchmark = True

Torch Script化したモデルの性能を計測します。

timings = timeGraph(mlm_model_ts.cuda(), enc_inputs['input_ids'], enc_inputs['token_type_ids'], enc_inputs['attention_mask'])

speed_med, speed_mean, time_med, time_mean = printStats("BERT", timings, batch_size)

dta_dict = {"speed_med": [speed_med], 
            "speed_mean": [speed_mean], 
            "time_med": [time_med], 
            "time_mean": [time_mean]}

著者の環境では下記のような性能になりました。

Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 599.1, mean: 597.6
  Median latency: 0.006677, mean: 0.006693, 99th_p: 0.006943, std_dev: 0.000059

Traceしたモデルで性能を計測します。

timings = timeGraph(traced_mlm_model.cuda(), enc_inputs['input_ids'], enc_inputs['token_type_ids'], enc_inputs['attention_mask'])

speed_med, speed_mean, time_med, time_mean = printStats("BERT", timings, batch_size)

dta_dict["speed_med"].append(speed_med)
dta_dict["speed_mean"].append(speed_mean)
dta_dict["time_med"].append(time_med)
dta_dict["time_mean"].append(time_mean)

著者の環境では下記のような性能になりました。

Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 951.2, mean: 951.0
  Median latency: 0.004205, mean: 0.004206, 99th_p: 0.004256, std_dev: 0.000015

Torch TensorRT形式に変換したモデルで計測します。

timings = timeGraph(trt_model, enc_inputs['input_ids'], enc_inputs['token_type_ids'], enc_inputs['attention_mask'])

speed_med, speed_mean, time_med, time_mean = printStats("BERT", timings, batch_size)

dta_dict["speed_med"].append(speed_med)
dta_dict["speed_mean"].append(speed_mean)
dta_dict["time_med"].append(time_med)
dta_dict["time_mean"].append(time_mean)

著者の環境では下記のような性能になりました。

Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 1216.9, mean: 1216.4
  Median latency: 0.003287, mean: 0.003289, 99th_p: 0.003317, std_dev: 0.000007

Torch TensorRT形式に変換+数値精度FP16のモデルで計測します。

timings = timeGraph(trt_model_fp16, enc_inputs['input_ids'], enc_inputs['token_type_ids'], enc_inputs['attention_mask'])

speed_med, speed_mean, time_med, time_mean = printStats("BERT", timings, batch_size)

dta_dict["speed_med"].append(speed_med)
dta_dict["speed_mean"].append(speed_mean)
dta_dict["time_med"].append(time_med)
dta_dict["time_mean"].append(time_mean)

著者の環境では下記のような性能になりました。

Warm up ...
Start timing ...

BERT =================================
batch size=4, num iterations=50
  Median text batches/second: 1776.7, mean: 1771.1
  Median latency: 0.002251, mean: 0.002259, 99th_p: 0.002305, std_dev: 0.000015

全結果をまとめて確認します。

import pandas as pd

pd.DataFrame.from_dict(dta_dict, orient='index', columns=["BERT_ts", "BERT_trace", "BERT_Torch_TensorRT", "BERT_Torch_TensorRT_FP16"]).T

Torch Scriptのモデルに比べて中央値のレイテンシーは約3倍ほど高速化できています。

Torch TensorRTがHugging Faceのモデルにも適用できました。

これでいろんな自然言語処理モデルの推論性能を高速化できるな!

Close Bitnami banner
Bitnami