Torch TensorRT で Int8 量子化を試す

Torch TensorRTでInt8量子化がしたいです。

量子化は大きく2つ方法があるのでこの方法を紹介するぞ

準備

Torch TensorRTの機能については下記をご覧ください

Torch TensorRTの環境はdockerを使用すると簡単に用意できます。

https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch

docker run --gpus all -it --rm -p 8888:8888 --name tensorrt nvcr.io/nvidia/pytorch:22.07-py3 

下記コードを参考に実行します。

https://github.com/pytorch/TensorRT/blob/master/notebooks/qat-ptq-workflow.ipynb

ライブラリのインポート

pipで必要なライブラリをインストールします。

!pip install ipywidgets --trusted-host pypi.org --trusted-host pypi.python.org --trusted-host=files.pythonhosted.org
!pip install wget

必要なライブラリをインポートします。

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision import models, datasets
import torch_tensorrt


import pytorch_quantization
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from pytorch_quantization.tensor_quant import QuantDescriptor
from pytorch_quantization import calib
from tqdm import tqdm

print(pytorch_quantization.__version__)

import os
import sys
import warnings
import time
import numpy as np
import wget
import tarfile
import shutil
warnings.simplefilter('ignore')

データの準備

データをダウンロードするための関数を用意します。

def download_data(DATA_DIR):
    if os.path.exists(DATA_DIR):
        if not os.path.exists(os.path.join(DATA_DIR, 'imagenette2-320')):
            url = 'https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz'
            wget.download(url)
            # open file
            file = tarfile.open('imagenette2-320.tgz')
            # extracting file
            file.extractall(DATA_DIR)
            file.close()
    else:
        print("This directory doesn't exist. Create the directory and run again")

ダウンロード用のディレクトリを作成します。

if not os.path.exists("./data"):
    os.mkdir("./data")
download_data("./data")

学習データと検証データのパスを設定します。

# Define main data directory
DATA_DIR = './data/imagenette2-320' 
# Define training and validation data paths
TRAIN_DIR = os.path.join(DATA_DIR, 'train') 
VAL_DIR = os.path.join(DATA_DIR, 'val')

データの前処理をしつつデータを取得するデータローダーを作成します.

#Performing Transformations on the dataset and defining training and validation dataloaders
transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            ])
train_dataset = datasets.ImageFolder(TRAIN_DIR, transform=transform)
val_dataset = datasets.ImageFolder(VAL_DIR, transform=transform)
calib_dataset = torch.utils.data.random_split(val_dataset, [2901, 1024])[1]

train_dataloader = data.DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
val_dataloader = data.DataLoader(val_dataset, batch_size=64, shuffle=False, drop_last=True)
calib_dataloader = data.DataLoader(calib_dataset, batch_size=64, shuffle=False, drop_last=True)

使用するデータセットを確認します。

# Visualising an image from the validation set
import matplotlib.pyplot as plt
for images, labels in val_dataloader:
    print(labels[0])
    image = images[0]
    img = image.swapaxes(0, 1)
    img = img.swapaxes(1, 2)
    plt.imshow(img)
    break

モデルの準備

モデルの準備をします。特徴抽出層はフリーズし、分類用のレイヤーを追加します。

#This function allows you to set the all the parameters to not have gradients, 
#allowing you to freeze the model and not undergo training during the train step. 
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

feature_extract = True #This varaible can be set False if you want to finetune the model by updating all the parameters. 
model = models.mobilenet_v2(pretrained=True)
set_parameter_requires_grad(model, feature_extract)
#Define a classification head for 10 classes.
model.classifier[1] = nn.Linear(1280, 10)
model = model.cuda()

学習率とロス、optimizerを設定します。

# Declare Learning rate
lr = 0.0001

# Use cross entropy loss for classification and SGD optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)

学習用の関数を設定します。

#Define functions for training, evalution, saving checkpoint and train parameter setting function
def train(model, dataloader, crit, opt, epoch):
    model.train()
    running_loss = 0.0
    for batch, (data, labels) in enumerate(dataloader):
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        opt.zero_grad()
        out = model(data)
        loss = crit(out, labels)
        loss.backward()
        opt.step()
        running_loss += loss.item()
        if batch % 100 == 99:
            print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100))
            running_loss = 0.0

評価用の関数を設定します。

def evaluate(model, dataloader, crit, epoch):
    total = 0
    correct = 0
    loss = 0.0
    class_probs = []
    class_preds = []
    model.eval()
    with torch.no_grad():
        for data, labels in dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

    evaluate_probs = torch.cat([torch.stack(batch) for batch in class_probs])
    evaluate_preds = torch.cat(class_preds)

    return loss / total, correct / total

モデルを保存するようの関数を設定します。

def save_checkpoint(state, ckpt_path="checkpoint.pth"):
    torch.save(state, ckpt_path)
    print("Checkpoint saved")

ベンチマーク用の関数を設定します。

cudnn.benchmark = True
# Helper function to benchmark the model
def benchmark(model, input_shape=(1024, 1, 32, 32), dtype='fp32', nwarmup=50, nruns=1000):
    input_data = torch.randn(input_shape)
    input_data = input_data.to("cuda")
    if dtype=='fp16':
        input_data = input_data.half()
        
    with torch.no_grad():
        for _ in range(nwarmup):
            features = model(input_data)
    torch.cuda.synchronize()

    timings = []
    with torch.no_grad():
        for i in range(1, nruns+1):
            start_time = time.time()
            output = model(input_data)
            torch.cuda.synchronize()
            end_time = time.time()
            timings.append(end_time - start_time)
    print('Average batch time: %.2f ms'%(np.mean(timings)*1000))

モデルを学習します。

# Train the model for 3 epochs to attain an acceptable accuracy.
num_epochs=3
for epoch in range(num_epochs):
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))

    train(model, train_dataloader, criterion, optimizer, epoch)
    test_loss, test_acc = evaluate(model, val_dataloader, criterion, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
    
save_checkpoint({'epoch': epoch + 1,
                 'model_state_dict': model.state_dict(),
                 'acc': test_acc,
                 'opt_state_dict': optimizer.state_dict()
                },
                ckpt_path="mobilenetv2_base_ckpt")

学習ログは下記のようになります。

Epoch: [    1 /     3] LR: 0.000100
Batch: [  100 |   295] loss: 2.299
Batch: [  200 |   295] loss: 2.219
Test Loss: 0.03214 Test Acc: 34.76%
Epoch: [    2 /     3] LR: 0.000100
Batch: [  100 |   295] loss: 2.053
Batch: [  200 |   295] loss: 1.965
Test Loss: 0.02858 Test Acc: 60.81%
Epoch: [    3 /     3] LR: 0.000100
Batch: [  100 |   295] loss: 1.838
Batch: [  200 |   295] loss: 1.771
Test Loss: 0.02544 Test Acc: 78.46%
Checkpoint saved

検証データの精度とベンチマークの結果を確認します。

#Evaluate and benchmark the performance of the baseline model
test_loss, test_acc = evaluate(model, val_dataloader, criterion, 0)
print("Mobilenetv2 Baseline accuracy: {:.2f}%".format(100 * test_acc))

benchmark(model, input_shape=(64, 3, 224, 224))

下記のように確認できます。

Mobilenetv2 Baseline accuracy: 78.46%
Average batch time: 24.36 ms

Torch TensorRT形式に変換します。

# Exporting to TorchScript
with torch.no_grad():
    data = iter(val_dataloader)
    images, _ = data.next()
    jit_model = torch.jit.trace(model, images.to("cuda"))
    torch.jit.save(jit_model, "mobilenetv2_base.jit.pt")

#Loading the Torchscript model and compiling it into a TensorRT model
baseline_model = torch.jit.load("mobilenetv2_base.jit.pt").eval()
compile_spec = {"inputs": [torch_tensorrt.Input([64, 3, 224, 224])]
               , "enabled_precisions": torch.float
               }
trt_base = torch_tensorrt.compile(baseline_model, **compile_spec)

Torch TensorRTで精度とベンチマークタイムを確認します。

# Evaluate and benchmark the performance of the baseline TRT model (TRT FP32 Model)
test_loss, test_acc = evaluate(trt_base, val_dataloader, criterion, 0)
print("Mobilenetv2 TRT Baseline accuracy: {:.2f}%".format(100 * test_acc))

benchmark(trt_base, input_shape=(64, 3, 224, 224))

精度は約3%下がっていますが、速度は2倍程度早くなっています。

Mobilenetv2 TRT Baseline accuracy: 75.38%
Average batch time: 10.32 ms

量子化

量子化は2つ方法があります。まず学習後に量子化する方法です。

FP32の数値精度をInt8の数値精度にマップするキャリブレーションがあります。キャリブレーションは3つ方法があります。

  • Min-Max: FP32の最大値と最小値をIn8のマッピングする値としてキャリブレーション時に設定します。丸め誤差が大きくなる可能性があります。
  • Entropy: クロスエントロピーを用いて情報誤差が最小になるようなレンジを設定します。
  • Percentile: データの絶対値の分布から1%をクリップして99%をカバーするレンジを設定します。

キャリブレーションはMin-Max方法を選んでいます。

calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(calib_dataloader,
                                              use_cache=False,
                                              algo_type=torch_tensorrt.ptq.CalibrationAlgo.MINMAX_CALIBRATION,
                                              device=torch.device('cuda:0'))

compile_spec = {
         "inputs": [torch_tensorrt.Input([64, 3, 224, 224])],
         "enabled_precisions": torch.int8,
         "calibrator": calibrator,
        "truncate_long_and_double": True
         
     }
trt_ptq = torch_tensorrt.compile(baseline_model, **compile_spec)

量子化したモデルの精度とベンチマークの結果を確認します。

# Evaluate the PTQ model
test_loss, test_acc = evaluate(trt_ptq, val_dataloader, criterion, 0)
print("Mobilenetv2 PTQ accuracy: {:.2f}%".format(100 * test_acc))

benchmark(trt_ptq, input_shape=(64, 3, 224, 224))

精度は6%下がり、速度は10倍程度速くなっています。

Mobilenetv2 PTQ accuracy: 72.67%
Average batch time: 2.36 ms

Quantization Aware Training

学習中に量子化誤差を考慮する方法で、疑似量子化を行うレイヤーと量子化を戻すレイヤーを入れています。

PyTorchではQuantization toolkitを提供しています。量子化をシミュレーションできます。

quant_modules.initialize()を使用するとモデルの内部でオリジナルのレイヤーからQAT用のレイヤーに置き換えられます。例えばQuantConv2d、QuantLinear、 QuantPooling に置き換えられます。

置き換え可能なレイヤーはこちらで確認できます。

quant_modules.initialize()

先程と同様にモデルを定義して保存済みモデルをロードします。自動的にQuantizeとDeQuantizeレイヤーが挿入されます。

# We define Mobilenetv2 again just like we did above
# All the regular conv, FC layers will be converted to their quantized counterparts due to quant_modules.initialize()
feature_extract = False
q_model = models.mobilenet_v2(pretrained=True)
set_parameter_requires_grad(q_model, feature_extract)
q_model.classifier[1] = nn.Linear(1280, 10)
q_model = q_model.cuda()

# mobilenetv2_base_ckpt is the checkpoint generated from Step 2 : Training a baseline Mobilenetv2 model.
ckpt = torch.load("./mobilenetv2_base_ckpt")
modified_state_dict={}
for key, val in ckpt["model_state_dict"].items():
    # Remove 'module.' from the key names
    if key.startswith('module'):
        modified_state_dict[key[7:]] = val
    else:
        modified_state_dict[key] = val

# Load the pre-trained checkpoint
q_model.load_state_dict(modified_state_dict)
optimizer.load_state_dict(ckpt["opt_state_dict"])

キャリブレーション結果をロードする関数を作成します。

def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
    model.cuda()

統計的な情報を取得します。Percentileの技術を使用してキャリブレーションをします。

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect statistics"""
    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    # Feed data to the network for collecting stats
    for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches):
        model(image.cuda())
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

先程作成した関数を用いて統計的な情報を取得し、キャリブレーションをします。

#Calibrate the model using percentile calibration technique.
with torch.no_grad():
    collect_stats(q_model, train_dataloader, num_batches=32)
    compute_amax(q_model, method="max")

QATモデルをファインチューニングします。

# Finetune the QAT model for 2 epochs
num_epochs=2
lr = 0.001
for epoch in range(num_epochs):
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))

    train(q_model, train_dataloader, criterion, optimizer, epoch)
    test_loss, test_acc = evaluate(q_model, val_dataloader, criterion, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
    
save_checkpoint({'epoch': epoch + 1,
                 'model_state_dict': q_model.state_dict(),
                 'acc': test_acc,
                 'opt_state_dict': optimizer.state_dict()
                },
                ckpt_path="mobilenetv2_qat_ckpt")

ファインチューニングによって1%程度精度が回復しています。下記にQATにかんする詳しい記事があります。

Achieving FP32 Accuracy for INT8 Inference Using Quantization Aware Training with NVIDIA TensorRT | NVIDIA Technical Blog https://developer.nvidia.com/blog/achieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/

Epoch: [    1 /     2] LR: 0.001000
Batch: [  100 |   295] loss: 1.766
Batch: [  200 |   295] loss: 1.792
Test Loss: 0.02721 Test Acc: 74.00%
Epoch: [    2 /     2] LR: 0.001000
Batch: [  100 |   295] loss: 1.775
Batch: [  200 |   295] loss: 1.777
Test Loss: 0.02712 Test Acc: 73.64%
Checkpoint saved

quant_nn.TensorQuantizer.use_fb_fake_quant = Trueを設定して推論すると内部的に

torch.fake_quantize_per_tensor_affine と torch.fake_quantize_per_channel_affine

を使用して量子化しTensorRT形式に対応するOperatorsに変換します。

quant_nn.TensorQuantizer.use_fb_fake_quant = True
with torch.no_grad():
    data = iter(val_dataloader)
    images, _ = data.next()
    jit_model = torch.jit.trace(q_model, images.to("cuda"))
    torch.jit.save(jit_model, "mobilenetv2_qat.jit.pt")

変換したTorch TensorRT形式のモデルをロードします。

#Loading the Torchscript model and compiling it into a TensorRT model
qat_model = torch.jit.load("mobilenetv2_qat.jit.pt").eval()
compile_spec = {"inputs": [torch_tensorrt.Input([64, 3, 224, 224])],
                "enabled_precisions": torch.int8
               }
trt_mod = torch_tensorrt.compile(qat_model, **compile_spec)

QATモデルの精度とベンチマークを計測します。

#Evaluate and benchmark the performance of the QAT-TRT model (TRT INT8)
test_loss, test_acc = evaluate(trt_mod, val_dataloader, criterion, 0)
print("Mobilenetv2 QAT accuracy using TensorRT: {:.2f}%".format(100 * test_acc))
benchmark(trt_mod, input_shape=(64, 3, 224, 224))

下記のように結果になります。各モデルの精度とパフォーマンスを比較します。

Mobilenetv2 QAT accuracy using TensorRT: 73.54%
Average batch time: 3.13 ms
ModelAccuracyPerformance
Baseline MobileNetv278.46%24.36 ms
Base + TRT
(TRT FP32)
75.38%10.32 ms
PTQ + TRT
(TRT int8)
72.67%2.36 ms
QAT+TRT
(TRT INT8)
73.54%3.13 ms

2つの方法で量子化を試せました。

Torch TensorRTのPTQとQATの量子化を簡単に試せるぞ

Close Bitnami banner
Bitnami