Torch TensorRTでInt8量子化がしたいです。
量子化は大きく2つ方法があるのでこの方法を紹介するぞ
準備
Torch TensorRTの機能については下記をご覧ください
Torch TensorRTの環境はdockerを使用すると簡単に用意できます。
https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
docker run --gpus all -it --rm -p 8888:8888 --name tensorrt nvcr.io/nvidia/pytorch:22.07-py3
下記コードを参考に実行します。
https://github.com/pytorch/TensorRT/blob/master/notebooks/qat-ptq-workflow.ipynb
ライブラリのインポート
pipで必要なライブラリをインストールします。
!pip install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org
!pip install wget
必要なライブラリをインポートします。
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import models, datasets
import torch_tensorrt
import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
from tqdm import tqdm
print(pytorch_quantization.__version__)
import os
import sys
import warnings
import time
import numpy as np
import wget
import tarfile
import shutil
warnings.simplefilter('ignore')
データの準備
データをダウンロードするための関数を用意します。
def download_data(DATA_DIR):
if os.path.exists(DATA_DIR):
if not os.path.exists(os.path.join(DATA_DIR, 'imagenette2-320')):
url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz'
wget.download(url)
# open file
file = tarfile.open('imagenette2-320.tgz')
# extracting file
file.extractall(DATA_DIR)
file.close()
else:
print("This directory doesn't exist. Create the directory and run again")
ダウンロード用のディレクトリを作成します。
if not os.path.exists("./data"):
os.mkdir("./data")
download_data("./data")
学習データと検証データのパスを設定します。
# Define main data directory
DATA_DIR = './data/imagenette2-320'
# Define training and validation data paths
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
VAL_DIR = os.path.join(DATA_DIR, 'val')
データの前処理をしつつデータを取得するデータローダーを作成します.
#Performing Transformations on the dataset and defining training and validation dataloaders
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
])
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)
val_dataset = datasets.ImageFolder(VAL_DIR, transform=transform)
calib_dataset = torch.utils.data.random_split(val_dataset, [2901, 1024])[1]
train_dataloader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)
calib_dataloader = data.DataLoader(calib_dataset, batch_size=64, shuffle=False, drop_last=True)
使用するデータセットを確認します。
# Visualising an image from the validation set
import matplotlib.pyplot as plt
for images, labels in val_dataloader:
print(labels[0])
image = images[0]
img = image.swapaxes(0, 1)
img = img.swapaxes(1, 2)
plt.imshow(img)
break

モデルの準備
モデルの準備をします。特徴抽出層はフリーズし、分類用のレイヤーを追加します。
#This function allows you to set the all the parameters to not have gradients,
#allowing you to freeze the model and not undergo training during the train step.
def set_parameter_requires_grad(model, feature_extracting):
if feature_extracting:
for param in model.parameters():
param.requires_grad = False
feature_extract = True #This varaible can be set False if you want to finetune the model by updating all the parameters.
model = models.mobilenet_v2(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
#Define a classification head for 10 classes.
model.classifier[1] = nn.Linear(1280, 10)
model = model.cuda()
学習率とロス、optimizerを設定します。
# Declare Learning rate
lr = 0.0001
# Use cross entropy loss for classification and SGD optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)
学習用の関数を設定します。
#Define functions for training, evalution, saving checkpoint and train parameter setting function
def train(model, dataloader, crit, opt, epoch):
model.train()
running_loss = 0.0
for batch, (data, labels) in enumerate(dataloader):
data, labels = data.cuda(), labels.cuda(non_blocking=True)
opt.zero_grad()
out = model(data)
loss = crit(out, labels)
loss.backward()
opt.step()
running_loss += loss.item()
if batch % 100 == 99:
print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100))
running_loss = 0.0
評価用の関数を設定します。
def evaluate(model, dataloader, crit, epoch):
total = 0
correct = 0
loss = 0.0
class_probs = []
class_preds = []
model.eval()
with torch.no_grad():
for data, labels in dataloader:
data, labels = data.cuda(), labels.cuda(non_blocking=True)
out = model(data)
loss += crit(out, labels)
preds = torch.max(out, 1)[1]
class_probs.append([F.softmax(i, dim=0) for i in out])
class_preds.append(preds)
total += labels.size(0)
correct += (preds == labels).sum().item()
evaluate_probs = torch.cat([torch.stack(batch) for batch in class_probs])
evaluate_preds = torch.cat(class_preds)
return loss / total, correct / total
モデルを保存するようの関数を設定します。
def save_checkpoint(state, ckpt_path="checkpoint.pth"):
torch.save(state, ckpt_path)
print("Checkpoint saved")
ベンチマーク用の関数を設定します。
cudnn.benchmark = True
# Helper function to benchmark the model
def benchmark(model, input_shape=(1024, 1, 32, 32), dtype='fp32', nwarmup=50, nruns=1000):
input_data = torch.randn(input_shape)
input_data = input_data.to("cuda")
if dtype=='fp16':
input_data = input_data.half()
with torch.no_grad():
for _ in range(nwarmup):
features = model(input_data)
torch.cuda.synchronize()
timings = []
with torch.no_grad():
for i in range(1, nruns+1):
start_time = time.time()
output = model(input_data)
torch.cuda.synchronize()
end_time = time.time()
timings.append(end_time - start_time)
print('Average batch time: %.2f ms'%(np.mean(timings)*1000))
モデルを学習します。
# Train the model for 3 epochs to attain an acceptable accuracy.
num_epochs=3
for epoch in range(num_epochs):
print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))
train(model, train_dataloader, criterion, optimizer, epoch)
test_loss, test_acc = evaluate(model, val_dataloader, criterion, epoch)
print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
save_checkpoint({'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'acc': test_acc,
'opt_state_dict': optimizer.state_dict()
},
ckpt_path="mobilenetv2_base_ckpt")
学習ログは下記のようになります。
Epoch: [ 1 / 3] LR: 0.000100
Batch: [ 100 | 295] loss: 2.299
Batch: [ 200 | 295] loss: 2.219
Test Loss: 0.03214 Test Acc: 34.76%
Epoch: [ 2 / 3] LR: 0.000100
Batch: [ 100 | 295] loss: 2.053
Batch: [ 200 | 295] loss: 1.965
Test Loss: 0.02858 Test Acc: 60.81%
Epoch: [ 3 / 3] LR: 0.000100
Batch: [ 100 | 295] loss: 1.838
Batch: [ 200 | 295] loss: 1.771
Test Loss: 0.02544 Test Acc: 78.46%
Checkpoint saved
検証データの精度とベンチマークの結果を確認します。
#Evaluate and benchmark the performance of the baseline model
test_loss, test_acc = evaluate(model, val_dataloader, criterion, 0)
print("Mobilenetv2 Baseline accuracy: {:.2f}%".format(100 * test_acc))
benchmark(model, input_shape=(64, 3, 224, 224))
下記のように確認できます。
Mobilenetv2 Baseline accuracy: 78.46%
Average batch time: 24.36 ms
Torch TensorRT形式に変換します。
# Exporting to TorchScript
with torch.no_grad():
data = iter(val_dataloader)
images, _ = data.next()
jit_model = torch.jit.trace(model, images.to("cuda"))
torch.jit.save(jit_model, "mobilenetv2_base.jit.pt")
#Loading the Torchscript model and compiling it into a TensorRT model
baseline_model = torch.jit.load("mobilenetv2_base.jit.pt").eval()
compile_spec = {"inputs": [torch_tensorrt.Input([64, 3, 224, 224])]
, "enabled_precisions": torch.float
}
trt_base = torch_tensorrt.compile(baseline_model, **compile_spec)
Torch TensorRTで精度とベンチマークタイムを確認します。
# Evaluate and benchmark the performance of the baseline TRT model (TRT FP32 Model)
test_loss, test_acc = evaluate(trt_base, val_dataloader, criterion, 0)
print("Mobilenetv2 TRT Baseline accuracy: {:.2f}%".format(100 * test_acc))
benchmark(trt_base, input_shape=(64, 3, 224, 224))
精度は約3%下がっていますが、速度は2倍程度早くなっています。
Mobilenetv2 TRT Baseline accuracy: 75.38%
Average batch time: 10.32 ms
量子化
量子化は2つ方法があります。まず学習後に量子化する方法です。
FP32の数値精度をInt8の数値精度にマップするキャリブレーションがあります。キャリブレーションは3つ方法があります。
- Min-Max: FP32の最大値と最小値をIn8のマッピングする値としてキャリブレーション時に設定します。丸め誤差が大きくなる可能性があります。
- Entropy: クロスエントロピーを用いて情報誤差が最小になるようなレンジを設定します。
- Percentile: データの絶対値の分布から1%をクリップして99%をカバーするレンジを設定します。

キャリブレーションはMin-Max方法を選んでいます。
calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(calib_dataloader,
use_cache=False,
algo_type=torch_tensorrt.ptq.CalibrationAlgo.MINMAX_CALIBRATION,
device=torch.device('cuda:0'))
compile_spec = {
"inputs": [torch_tensorrt.Input([64, 3, 224, 224])],
"enabled_precisions": torch.int8,
"calibrator": calibrator,
"truncate_long_and_double": True
}
trt_ptq = torch_tensorrt.compile(baseline_model, **compile_spec)
量子化したモデルの精度とベンチマークの結果を確認します。
# Evaluate the PTQ model
test_loss, test_acc = evaluate(trt_ptq, val_dataloader, criterion, 0)
print("Mobilenetv2 PTQ accuracy: {:.2f}%".format(100 * test_acc))
benchmark(trt_ptq, input_shape=(64, 3, 224, 224))
精度は6%下がり、速度は10倍程度速くなっています。
Mobilenetv2 PTQ accuracy: 72.67%
Average batch time: 2.36 ms
Quantization Aware Training
学習中に量子化誤差を考慮する方法で、疑似量子化を行うレイヤーと量子化を戻すレイヤーを入れています。

PyTorchではQuantization toolkitを提供しています。量子化をシミュレーションできます。
quant_modules.initialize()を使用するとモデルの内部でオリジナルのレイヤーからQAT用のレイヤーに置き換えられます。例えばQuantConv2d、QuantLinear
、 QuantPooling
に置き換えられます。
置き換え可能なレイヤーはこちらで確認できます。
quant_modules.initialize()
先程と同様にモデルを定義して保存済みモデルをロードします。自動的にQuantizeとDeQuantizeレイヤーが挿入されます。
# We define Mobilenetv2 again just like we did above
# All the regular conv, FC layers will be converted to their quantized counterparts due to quant_modules.initialize()
feature_extract = False
q_model = models.mobilenet_v2(pretrained=True)
set_parameter_requires_grad(q_model, feature_extract)
q_model.classifier[1] = nn.Linear(1280, 10)
q_model = q_model.cuda()
# mobilenetv2_base_ckpt is the checkpoint generated from Step 2 : Training a baseline Mobilenetv2 model.
ckpt = torch.load("./mobilenetv2_base_ckpt")
modified_state_dict={}
for key, val in ckpt["model_state_dict"].items():
# Remove 'module.' from the key names
if key.startswith('module'):
modified_state_dict[key[7:]] = val
else:
modified_state_dict[key] = val
# Load the pre-trained checkpoint
q_model.load_state_dict(modified_state_dict)
optimizer.load_state_dict(ckpt["opt_state_dict"])
キャリブレーション結果をロードする関数を作成します。
def compute_amax(model, **kwargs):
# Load calib result
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
if isinstance(module._calibrator, calib.MaxCalibrator):
module.load_calib_amax()
else:
module.load_calib_amax(**kwargs)
model.cuda()
統計的な情報を取得します。Percentileの技術を使用してキャリブレーションをします。
def collect_stats(model, data_loader, num_batches):
"""Feed data to the network and collect statistics"""
# Enable calibrators
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module.disable_quant()
module.enable_calib()
else:
module.disable()
# Feed data to the network for collecting stats
for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
model(image.cuda())
if i >= num_batches:
break
# Disable calibrators
for name, module in model.named_modules():
if isinstance(module, quant_nn.TensorQuantizer):
if module._calibrator is not None:
module.enable_quant()
module.disable_calib()
else:
module.enable()
先程作成した関数を用いて統計的な情報を取得し、キャリブレーションをします。
#Calibrate the model using percentile calibration technique.
with torch.no_grad():
collect_stats(q_model, train_dataloader, num_batches=32)
compute_amax(q_model, method="max")
QATモデルをファインチューニングします。
# Finetune the QAT model for 2 epochs
num_epochs=2
lr = 0.001
for epoch in range(num_epochs):
print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))
train(q_model, train_dataloader, criterion, optimizer, epoch)
test_loss, test_acc = evaluate(q_model, val_dataloader, criterion, epoch)
print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
save_checkpoint({'epoch': epoch + 1,
'model_state_dict': q_model.state_dict(),
'acc': test_acc,
'opt_state_dict': optimizer.state_dict()
},
ckpt_path="mobilenetv2_qat_ckpt")
ファインチューニングによって1%程度精度が回復しています。下記にQATにかんする詳しい記事があります。
Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT | NVIDIA Technical Blog https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/
Epoch: [ 1 / 2] LR: 0.001000
Batch: [ 100 | 295] loss: 1.766
Batch: [ 200 | 295] loss: 1.792
Test Loss: 0.02721 Test Acc: 74.00%
Epoch: [ 2 / 2] LR: 0.001000
Batch: [ 100 | 295] loss: 1.775
Batch: [ 200 | 295] loss: 1.777
Test Loss: 0.02712 Test Acc: 73.64%
Checkpoint saved
quant_nn.TensorQuantizer.use_fb_fake_quant = Trueを設定して推論すると内部的に
torch.fake_quantize_per_tensor_affine
と torch.fake_quantize_per_channel_affine
を使用して量子化しTensorRT形式に対応するOperatorsに変換します。
quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
data = iter(val_dataloader)
images, _ = data.next()
jit_model = torch.jit.trace(q_model, images.to("cuda"))
torch.jit.save(jit_model, "mobilenetv2_qat.jit.pt")
変換したTorch TensorRT形式のモデルをロードします。
#Loading the Torchscript model and compiling it into a TensorRT model
qat_model = torch.jit.load("mobilenetv2_qat.jit.pt").eval()
compile_spec = {"inputs": [torch_tensorrt.Input([64, 3, 224, 224])],
"enabled_precisions": torch.int8
}
trt_mod = torch_tensorrt.compile(qat_model, **compile_spec)
QATモデルの精度とベンチマークを計測します。
#Evaluate and benchmark the performance of the QAT-TRT model (TRT INT8)
test_loss, test_acc = evaluate(trt_mod, val_dataloader, criterion, 0)
print("Mobilenetv2 QAT accuracy using TensorRT: {:.2f}%".format(100 * test_acc))
benchmark(trt_mod, input_shape=(64, 3, 224, 224))
下記のように結果になります。各モデルの精度とパフォーマンスを比較します。
Mobilenetv2 QAT accuracy using TensorRT: 73.54%
Average batch time: 3.13 ms
Model | Accuracy | Performance |
Baseline MobileNetv2 | 78.46% | 24.36 ms |
Base + TRT (TRT FP32) | 75.38% | 10.32 ms |
PTQ + TRT (TRT int8) | 72.67% | 2.36 ms |
QAT+TRT (TRT INT8) | 73.54% | 3.13 ms |
2つの方法で量子化を試せました。
Torch TensorRTのPTQとQATの量子化を簡単に試せるぞ