@@ -49,6 +49,9 @@ def compile(
4949 cuda_graph_batch_size = - 1 ,
5050 is_aten = False ,
5151 use_experimental_fx_rt = False ,
52+ max_aux_streams = None ,
53+ version_compatible = False ,
54+ optimization_level = None ,
5255 num_avg_timing_iters = 1 ,
5356 torch_executed_ops = [],
5457 torch_executed_modules = [],
@@ -68,14 +71,12 @@ def compile(
6871 save_timing_cache: Update timing cache with current timing cache data if set to True.
6972 cuda_graph_batch_size: Cuda graph batch size, default to be -1.
7073 use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
74+ max_aux_streams: max number of aux streams to use
75+ version_compatible: enable version compatible feature
76+ optimization_level: builder optimization level
7177 Returns:
7278 A torch.nn.Module lowered by TensorRT.
7379 """
74- if use_experimental_fx_rt and not explicit_batch_dimension :
75- raise ValueError (
76- "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
77- )
78-
7980 logger .warn (
8081 "For ir=fx_ts_compat backend only the "
8182 + "following arguments are supported: "
@@ -123,6 +124,9 @@ def compile(
123124 cuda_graph_batch_size = cuda_graph_batch_size ,
124125 is_aten = is_aten ,
125126 use_experimental_rt = use_experimental_fx_rt ,
127+ max_aux_streams = max_aux_streams ,
128+ version_compatible = version_compatible ,
129+ optimization_level = optimization_level ,
126130 )
127131 lowerer = Lowerer .create (lower_setting = lower_setting )
128132 return lowerer (module , inputs )
@@ -162,8 +166,6 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
162166 interpreter = TRTInterpreter (
163167 mod ,
164168 input_specs = self .lower_setting .input_specs ,
165- explicit_batch_dimension = self .lower_setting .explicit_batch_dimension ,
166- explicit_precision = self .lower_setting .explicit_precision ,
167169 logger_level = trt .Logger .VERBOSE
168170 if self .lower_setting .debug
169171 else trt .Logger .WARNING ,
@@ -198,7 +200,7 @@ def default_split_function(
198200 model : fx .GraphModule , inputs : Input , lower_setting : LowerSetting
199201) -> SplitResult :
200202 splitter_setting = TRTSplitterSetting ()
201- splitter_setting .use_implicit_batch_dim = not lower_setting . explicit_batch_dimension
203+ splitter_setting .use_implicit_batch_dim = False
202204 splitter_setting .min_block_size = lower_setting .min_block_size
203205 splitter_setting .use_experimental_rt = lower_setting .use_experimental_rt
204206 splitter = TRTSplitter (model , inputs , settings = splitter_setting )
0 commit comments