55import torch
66from torch import fx
77from enum import Enum
8-
8+ from torch_tensorrt import fx
99
1010class _IRType (Enum ):
1111 """Enum to set the minimum required logging level to print a message to stdout
@@ -43,13 +43,7 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
4343 if module_is_tsable and ir_targets_torchscript :
4444 return _IRType .ts
4545 elif module_is_fxable and ir_targets_fx :
46- if module_type == _ModuleType .fx :
47- raise ValueError ("Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT" )
48- elif ir_targets_fx :
49- raise ValueError ("Preferred ir was set to \" fx\" which is currently not supported by Torch-TensorRT" )
50- else :
51- raise ValueError ("Torch-TensorRT currently does not support fx" )
52- # return _IRType.fx
46+ return _IRType .fx
5347 else :
5448 if ir == "default" :
5549 # Options are listed in order of preference
@@ -114,7 +108,78 @@ def compile(module: Any, ir="default", inputs=[], enabled_precisions=set([_enums
114108 ts_mod = torch .jit .script (module )
115109 return torch_tensorrt .ts .compile (ts_mod , inputs = inputs , enabled_precisions = enabled_precisions , ** kwargs )
116110 elif target_ir == _IRType .fx :
117- raise RuntimeError ("fx is currently not supported" )
111+ from torch_tensorrt .fx .tracer .acc_tracer import acc_tracer
112+ from torch_tensorrt .fx import InputTensorSpec
113+ from torch_tensorrt .fx import TRTInterpreter
114+ from torch_tensorrt .fx .passes .lower_basic_pass import transform_setitem
115+ from torch_tensorrt .fx .tools .trt_splitter import TRTSplitter
116+ from torch_tensorrt .fx .tools .trt_splitter import TRTSplitterSetting
117+ from torch_tensorrt .fx .trt_module import TRTModule
118+ from torch_tensorrt .fx .utils import LowerPrecision
119+ acc_model = acc_tracer .trace (module , inputs )
120+
121+ splitter_setting = TRTSplitterSetting ()
122+ splitter_setting .use_implicit_batch_dim = False
123+ splitter = TRTSplitter (acc_model , inputs , settings = splitter_setting )
124+ splitter .node_support_preview ()
125+ split_mod = splitter ()
126+ num_piece = 0
127+ for name , _ in split_mod .named_children ():
128+ print (f"graph is split into { name } " )
129+ num_piece += 1
130+
131+ # if the graph module is split into pieces larger than 8, we consider its perf
132+ # is not good and fall back to non-TRT
133+ if num_piece > 8 :
134+ print (
135+ f"The graph module is split into { num_piece } which is large than the \
136+ threshold=8. Fall back to non-TRT module."
137+ )
138+ return None
139+
140+ if torch .float16 in enabled_precisions or torch .half in enabled_precisions :
141+ precision = LowerPrecision .FP16
142+ else :
143+ precision = LowerPrecision .FP32
144+
145+ def get_submod_inputs (mod , submod , inputs ):
146+ acc_inputs = None
147+
148+ def get_input (self , inputs ):
149+ nonlocal acc_inputs
150+ acc_inputs = inputs
151+
152+ handle = submod .register_forward_pre_hook (get_input )
153+ mod (* inputs )
154+ handle .remove ()
155+ return acc_inputs
156+
157+ for name , _ in split_mod .named_children ():
158+ if "_run_on_acc" in name :
159+ submod = getattr (split_mod , name )
160+ # Get submodule inputs for fx2trt
161+ acc_inputs = get_submod_inputs (split_mod , submod , inputs )
162+
163+ # fx2trt replacement
164+ interp = TRTInterpreter (
165+ submod ,
166+ InputTensorSpec .from_tensors (acc_inputs ),
167+ explicit_batch_dimension = True ,
168+ )
169+ r = interp .run (
170+ max_workspace_size = 20 << 30 ,
171+ lower_precision = precision ,
172+ # profiling_verbosity=trt.ProfilingVerbosity.DETAILED, #For profile
173+ )
174+ # For profile
175+ # from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module
176+ # profile_trt_module("", trt_mod, acc_inputs)
177+ trt_mod = TRTModule (* r )
178+
179+ setattr (split_mod , name , trt_mod )
180+ else :
181+ submod = getattr (split_mod , name )
182+ return split_mod
118183 else :
119184 raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
120185
@@ -173,4 +238,4 @@ def convert_method_to_trt_engine(module: Any,
173238 elif target_ir == _IRType .fx :
174239 raise RuntimeError ("fx is currently not supported" )
175240 else :
176- raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
241+ raise RuntimeError ("Module is an unknown format or the ir requested is unknown" )
0 commit comments