
Commit 4ecbaac

Author: Wei Wei (committed)

rename to --fx-only

1 parent 5c92884 · commit 4ecbaac

1 file changed: +101, -0

py/torch_tensorrt/fx/compile.py

Lines changed: 101 additions & 0 deletions

```python
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch

from typing import Any

from torch_tensorrt.fx.tracer.acc_tracer import acc_tracer
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx import TRTInterpreter
from torch_tensorrt.fx.passes.lower_basic_pass import transform_setitem
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitterSetting
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt.fx.utils import LowerPrecision


def compile(
    module: Any,
    ir="default",
    inputs=[],
    enabled_precisions=set([torch.float]),
    **kwargs,
):
    """Compile a PyTorch module through fx

    Takes an existing PyTorch module and a set of settings to configure the
    compiler, lowers and compiles the module to TensorRT using the path
    specified in ``ir``, and returns a PyTorch Module back.

    Converts specifically the forward method of a Module.

    Arguments:
        module (torch.nn.Module): Source module

    Keyword Arguments:
        inputs (List[torch.Tensor]): Inputs for the fixed-shape scenario; input shapes cannot change
        enabled_precisions (Set[torch.dtype]): The datatypes that TensorRT can use when selecting kernels. If torch.float is chosen, the kernels run with fp32; if torch.float16 is chosen, the kernels run with fp16 or fp32, as selected by TensorRT
        ir (str): The requested strategy to compile. (default is ts - TorchScript with scripting path, fx is FX based path)
        **kwargs: Additional settings for the specific requested strategy (See submodules for more info)

    Returns:
        torch.nn.Module: Compiled Module; when run it will execute via TensorRT
    """
    acc_model = acc_tracer.trace(module, inputs)

    # Split the traced graph into TRT-convertible and non-convertible pieces.
    splitter_setting = TRTSplitterSetting()
    splitter_setting.use_implicit_batch_dim = False
    splitter = TRTSplitter(acc_model, inputs, settings=splitter_setting)
    splitter.node_support_preview()
    split_mod = splitter()
    num_piece = 0
    for name, _ in split_mod.named_children():
        print(f"graph is split into {name}")
        num_piece += 1

    # If the graph module is split into more than 8 pieces, we consider its
    # perf not good and fall back to non-TRT.
    if num_piece > 8:
        print(
            f"The graph module is split into {num_piece} pieces, which is "
            "larger than the threshold of 8. Fall back to non-TRT module."
        )
        return None

    if torch.float16 in enabled_precisions or torch.half in enabled_precisions:
        precision = LowerPrecision.FP16
    else:
        precision = LowerPrecision.FP32

    def get_submod_inputs(mod, submod, inputs):
        # Capture the positional inputs that `submod` receives during one
        # full forward pass of `mod`, using a temporary forward pre-hook.
        acc_inputs = None

        def get_input(module, inputs):
            nonlocal acc_inputs
            acc_inputs = inputs

        handle = submod.register_forward_pre_hook(get_input)
        mod(*inputs)
        handle.remove()
        return acc_inputs

    for name, _ in split_mod.named_children():
        if "_run_on_acc" in name:
            submod = getattr(split_mod, name)
            # Get submodule inputs for fx2trt
            acc_inputs = get_submod_inputs(split_mod, submod, inputs)

            # fx2trt replacement
            interp = TRTInterpreter(
                submod,
                InputTensorSpec.from_tensors(acc_inputs),
                explicit_batch_dimension=True,
            )
            r = interp.run(
                max_workspace_size=20 << 30,
                lower_precision=precision,
                # profiling_verbosity=trt.ProfilingVerbosity.DETAILED,  # For profile
            )
            # For profile
            # from torch_tensorrt.fx.tools.trt_profiler_sorted import profile_trt_module
            # profile_trt_module("", trt_mod, acc_inputs)
            trt_mod = TRTModule(*r)

            setattr(split_mod, name, trt_mod)
        else:
            # Non-TRT pieces keep running as regular PyTorch submodules.
            submod = getattr(split_mod, name)
    return split_mod
```
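For context, a minimal sketch of how this entry point could be exercised end to end. The toy model, shapes, and the `torch_tensorrt.fx.compile` import path are illustrative assumptions; an actual run needs a CUDA GPU with TensorRT installed.

```python
import torch
import torch.nn as nn

from torch_tensorrt.fx.compile import compile as fx_compile  # assumed path from this commit

class ToyNet(nn.Module):
    """Hypothetical model small enough to convert as one TRT piece."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = ToyNet().cuda().eval()
sample_inputs = [torch.randn(8, 16, device="cuda")]

trt_model = fx_compile(model, inputs=sample_inputs, enabled_precisions={torch.float16})

# compile() returns None when the split exceeds the 8-piece threshold.
if trt_model is not None:
    print(trt_model(*sample_inputs).shape)  # torch.Size([8, 16])
```

Passing `{torch.float16}` exercises the `LowerPrecision.FP16` branch above; any other precision set falls back to `LowerPrecision.FP32`.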

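The `get_submod_inputs` helper relies on a generally useful PyTorch pattern: a temporary forward pre-hook records the exact positional inputs a submodule receives while the parent module runs once. A standalone sketch of that pattern, with illustrative module names:

```python
import torch
import torch.nn as nn

class Parent(nn.Module):
    """Illustrative two-stage model."""

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Linear(4, 8)
        self.stage2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.stage2(self.stage1(x))

parent = Parent()
captured = None

def capture(module, inputs):
    # Forward pre-hooks are called as hook(module, positional_inputs).
    global captured
    captured = inputs

handle = parent.stage2.register_forward_pre_hook(capture)
parent(torch.randn(1, 4))  # one full forward pass triggers the hook
handle.remove()  # detach so later forwards are unaffected

print(captured[0].shape)  # torch.Size([1, 8]): the input stage2 actually saw
```

Removing the hook right after the capture run matters; `compile` does the same with `handle.remove()` so the returned module's forwards are not intercepted.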