Torch TensorRTのfxによるTensorRT変換可能なノードの可視化

Torch TensorRTはTensorRTに変換できる部分だけ変換できるみたいですが、どのノードが変換できているか知りたいです。

可視化機能を使用するとどのノードがTensorRTに変換できるか把握できるぞ

準備

Torch TensorRTの機能については下記をご覧ください

Torch TensorRTの環境はdockerを使用すると簡単に用意できます。

https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch

docker run --gpus all -it --rm --name tensorrt nvcr.io/nvidia/pytorch:22.07-py3 

下記コードを参考に実行します。

https://github.com/pytorch/TensorRT/blob/master/examples/fx/fx2trt_example.py

レイヤーの可視化

必要なライブラリをインポートします。

import torch
import torch.fx
import torch.nn as nn
from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter

今回、解析するモデルを定義します。

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = torch.linalg.norm(x, ord=2, dim=1)
        x = self.relu(x)
        return x

サンプルの入力データとモデルのインスタンスを作成します。

inputs = [torch.randn((1, 10), device=torch.device("cuda"))]
model = Model().cuda().eval()

acc_tracerを用いてPyTorch operatorsをトレースします。

traced = acc_tracer.trace(model, inputs)

acc_tracerの結果を確認します。

モデルは線形層のみのシンプルなものです。

GraphModule(
  (linear): Module()
)



def forward(self, x):
    linear_weight = self.linear.weight
    linear_bias = self.linear.bias
    linear_1 = torch_tensorrt_fx_tracer_acc_tracer_acc_ops_linear(input = x, weight = linear_weight, bias = linear_bias);  x = linear_weight = linear_bias = None
    relu_2 = torch_tensorrt_fx_tracer_acc_tracer_acc_ops_relu(input = linear_1, inplace = False);  linear_1 = None
    linalg_norm_1 = torch_tensorrt_fx_tracer_acc_tracer_acc_ops_linalg_norm(input = relu_2, ord = 2, dim = 1, keepdim = False);  relu_2 = None
    relu_3 = torch_tensorrt_fx_tracer_acc_tracer_acc_ops_relu(input = linalg_norm_1, inplace = False);  linalg_norm_1 = None
    return relu_3

TRTSpilitterを用いるとTensorRTを使用可能なノードとそうでないノードを把握できます。

run_on_acc_{}: TensorRTで変換可能なノード

run_on_gpu_{}: TensorRTで変換できないノード

splitter.node_support_previewで可視化しています。

dump_graph=Trueを設定することでgraphvizで可視化可能なファイルを出力しています。

splitter = TRTSplitter(traced, inputs)

splitter.node_support_preview(dump_graph=True)

TensorRTで変換可能なノードとそうでないノードが確認できます。

Supported node types in the model:
acc_ops.linear: ((), {'input': torch.float32, 'weight': torch.float32, 'bias': torch.float32})
acc_ops.relu: ((), {'input': torch.float32})

Unsupported node types in the model:
acc_ops.linalg_norm: ((), {'input': torch.float32})

下記コードでTensorRTを使用可能なノードとそうでないノードで分割されたグラフを確認します。

split_mod = splitter()

print(split_mod.graph)

グラフの情報は下記のように確認できます。

graph():
    %x : [#users=1] = placeholder[target=x]
    %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
    %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
    %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
    return _run_on_acc_2

各ノードの詳細情報を確認します。

print(split_mod._run_on_acc_0.graph)
print(split_mod._run_on_gpu_1.graph)
print(split_mod._run_on_acc_2.graph)

下記のように各ノードの情報を確認できます。

graph():
    %x : [#users=1] = placeholder[target=x]
    %linear_weight : [#users=1] = get_attr[target=linear.weight]
    %linear_bias : [#users=1] = get_attr[target=linear.bias]
    %linear_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linear](args = (), kwargs = {input: %x, weight: %linear_weight, bias: %linear_bias})
    %relu_2 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linear_1, inplace: False})
    return relu_2
graph():
    %relu_2 : [#users=1] = placeholder[target=relu_2]
    %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), kwargs = {input: %relu_2, ord: 2, dim: 1, keepdim: False})
    return linalg_norm_1
graph():
    %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
    %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
    return relu_3

dotファイルが作成されているので、その結果を確認します。

ホスト側で確認するため、ホストからデータを取得します。

docker cp tensorrt:/workspace/node_support.dot .

graphvizをインストールして可視化できるファイルに変換します。

sudo apt install -y graphviz

dot -Tps node_support.dot -o node_support.ps

緑がTensorRTに変換可能なノード、赤が変換できないノードになります。

可視化できました。

これでボトルネック把握などデバッグがスムーズになるな

Close Bitnami banner
Bitnami