import torch
import torch.fx
import torch.nn as nn
from torch_tensorrt.fx.utils import LowerPrecision
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter

# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
# model to TensorRT via FX with existing FX based tooling. The general lowering flow
# is as follows:
# 1. Trace the model with acc_tracer so that PyTorch ops are mapped to acc ops.
# 2. Split the traced graph into TensorRT-supported and unsupported submodules.
# 3. Lower each supported submodule to a TensorRT engine.


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Layer sizes are assumed from the (1, 10) example input below.
        self.linear = nn.Linear(10, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = torch.linalg.norm(x, ord=2, dim=1)
        x = self.relu(x)
        return x


inputs = [torch.randn((1, 10), device=torch.device("cuda"))]
model = Model().cuda().eval()

# acc_tracer is a custom FX tracer that maps nodes whose targets are PyTorch operators
# to acc ops.
traced = acc_tracer.trace(model, inputs)

# TRTSplitter splits the traced module into `_run_on_acc_*` submodules that can be
# lowered to TensorRT and `_run_on_gpu_*` submodules that contain unsupported ops
# and therefore keep running on the GPU in eager mode.
splitter = TRTSplitter(traced, inputs)

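# Before splitting, we can preview which node types are supported vs. unsupported,
# which explains why the graph gets split the way it does. (`node_support_preview`
# is assumed available on the splitter here; it exists on FX splitter bases in
# recent versions.)
splitter.node_support_preview()
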
# Split.
split_mod = splitter()

# After split we have three submodules: _run_on_acc_0, _run_on_gpu_1 and _run_on_acc_2.
print(split_mod.graph)
6970"""
7071graph():
7172 %x : [#users=1] = placeholder[target=x]
7273 %_run_on_acc_0 : [#users=1] = call_module[target=_run_on_acc_0](args = (%x,), kwargs = {})
7374 %_run_on_gpu_1 : [#users=1] = call_module[target=_run_on_gpu_1](args = (%_run_on_acc_0,), kwargs = {})
74- return _run_on_gpu_1
75+ %_run_on_acc_2 : [#users=1] = call_module[target=_run_on_acc_2](args = (%_run_on_gpu_1,), kwargs = {})
76+ return _run_on_acc_2
7577"""

# Take a look at what's inside each submodule. _run_on_acc_0 contains linear and relu,
# while _run_on_gpu_1 contains linalg_norm, which is currently not supported by fx2trt.
# _run_on_acc_2 is another supported submodule, holding the second relu.
print(split_mod._run_on_acc_0.graph)
print(split_mod._run_on_gpu_1.graph)
print(split_mod._run_on_acc_2.graph)
8185"""
8286graph():
8387 %x : [#users=1] = placeholder[target=x]
@@ -90,32 +94,51 @@ def forward(self, x):
9094 %relu_1 : [#users=1] = placeholder[target=relu_1]
9195 %linalg_norm_1 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.linalg_norm](args = (), ...
9296 return linalg_norm_1
97+ graph():
98+ %linalg_norm_1 : [#users=1] = placeholder[target=linalg_norm_1]
99+ %relu_3 : [#users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %linalg_norm_1, inplace: False})
100+ return relu_3
93101"""
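
# As an illustrative aside (not part of the original example): list the split
# submodules and whether each is destined for TensorRT, based on the naming
# convention TRTSplitter uses, shown above.
for name, _ in split_mod.named_children():
    print(name, "->", "TensorRT" if "_run_on_acc" in name else "GPU eager")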


def get_submod_inputs(mod, submod, inputs):
    """Run `mod` once and use a forward pre-hook to capture the tensors that
    flow into `submod`, so fx2trt can build an InputTensorSpec for it."""
    acc_inputs = None

    def get_input(self, inputs):
        nonlocal acc_inputs
        acc_inputs = inputs

    handle = submod.register_forward_pre_hook(get_input)
    mod(*inputs)
    handle.remove()
    return acc_inputs


# Since the model is split into three segments, we need to lower each TRT-eligible
# segment. If we know the model can be fully lowered, we can skip the splitter part.
for name, _ in split_mod.named_children():
    if "_run_on_acc" in name:
        submod = getattr(split_mod, name)
        # Get submodule inputs for fx2trt
        acc_inputs = get_submod_inputs(split_mod, submod, inputs)

        # fx2trt replacement
        interp = TRTInterpreter(
            submod,
            InputTensorSpec.from_tensors(acc_inputs),
            explicit_batch_dimension=True,
        )
        r = interp.run(lower_precision=LowerPrecision.FP32)
        trt_mod = TRTModule(*r)
        setattr(split_mod, name, trt_mod)

lowered_model_output = split_mod(*inputs)

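# A possible variant (an assumption, not part of the original example): passing
# `lower_precision=LowerPrecision.FP16` to `interp.run()` in the loop above would
# build FP16 engines instead; the comparison tolerances below would likely need
# loosening in that case.
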
# Save and load model
torch.save(split_mod, "trt.pt")
reload_trt_mod = torch.load("trt.pt")
reload_model_output = reload_trt_mod(*inputs)
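
# Extra sanity check (added here, not in the original example): the reloaded
# module should agree with the in-memory lowered module on the same inputs.
torch.testing.assert_close(
    reload_model_output, lowered_model_output, atol=3e-3, rtol=1e-2
)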

# Make sure the results match
regular_model_output = model(*inputs)
torch.testing.assert_close(
    reload_model_output, regular_model_output, atol=3e-3, rtol=1e-2
)

# We can utilize the trt profiler to print out the time spent on each layer.
# After the loop above `trt_mod` refers to the last lowered submodule, so here we
# profile the first lowered submodule, whose inputs are the original example inputs.
trt_acc_0 = split_mod._run_on_acc_0
trt_acc_0.enable_profiling()
trt_acc_0(*inputs)
"""
Reformatting CopyNode for Input Tensor 0 to LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.027392ms
LayerType.FULLY_CONNECTED_acc_ops.linear_linear_1: 0.023072ms
PWN(ActivationType.RELU_acc_ops.relu_relu_1): 0.008928ms
"""
trt_acc_0.disable_profiling()