 import logging
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence

 import numpy
@@ -40,6 +41,7 @@ def __init__(
         explicit_batch_dimension: bool = True,
         explicit_precision: bool = False,
         logger_level=None,
+        output_dtypes=None,
     ):
         super().__init__(module)

@@ -78,6 +80,9 @@ def __init__(
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()

+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes
+
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
@@ -178,13 +183,17 @@ def run(
             algorithm_selector: set up algorithm selection for certain layers
             timing_cache: enable timing cache for TensorRT
             profiling_verbosity: TensorRT logging level
+            max_aux_streams: maximum number of auxiliary TRT streams allowed for each engine
+            version_compatible: provide version forward compatibility for engine plan files
+            optimization_level: builder optimization level 0-5; higher levels imply longer
+                build time, searching for more optimization options. TRT defaults to 3
         Return:
             TRTInterpreterResult
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overridden by specifying output_dtypes.
         self.output_fp16 = (
             not force_fp32_output and lower_precision == LowerPrecision.FP16
         )
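The three new knobs documented above map onto builder-config settings applied later in `run()`. A hypothetical call sketch, assuming `module` and `input_specs` already exist; the full constructor and `run()` signatures are not shown in this diff:

```python
import tensorrt as trt

# Hypothetical usage sketch; module and input_specs are assumed to be a
# torch.fx.GraphModule and its List[InputTensorSpec], respectively.
interpreter = TRTInterpreter(module, input_specs)
result = interpreter.run(
    max_aux_streams=2,        # cap auxiliary TRT streams per engine
    version_compatible=True,  # emit a forward-compatible engine plan
    optimization_level=5,     # 0-5; higher = longer build, wider search
    profiling_verbosity=trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
)
```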
@@ -224,14 +233,14 @@ def run(
             cache = builder_config.create_timing_cache(b"")
             builder_config.set_timing_cache(cache, False)

-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )

-        if trt.__version__ >= "8.6":
+        if version.parse(trt.__version__) >= version.parse("8.6"):
             if max_aux_streams is not None:
                 _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
                 builder_config.max_aux_streams = max_aux_streams
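The move to `packaging.version` fixes a latent bug: raw string comparison is lexicographic, so once TensorRT minor versions reach two digits (`"8.10" < "8.2"` as strings), both branches above would be skipped on newer releases. A minimal illustration:

```python
from packaging import version

# Lexicographic: compared char by char, '1' < '2', so "8.10" sorts before "8.2".
assert not ("8.10" >= "8.2")
# Semantic: parsed components compare numerically, 10 > 2.
assert version.parse("8.10") >= version.parse("8.2")
```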
@@ -372,6 +381,11 @@ def output(self, target, args, kwargs):
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")

+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
+            )
+
         for i, output in enumerate(outputs):
             if any(
                 op_name in output.name.split("_")
@@ -396,6 +410,8 @@ def output(self, target, args, kwargs):
             self.network.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
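The chain above gives explicit `output_dtypes` priority over the fp16 fallback, while bool outputs still take precedence over both. A hypothetical construction pinning output dtypes, assuming `module` and `input_specs` as before; the list must have one entry per graph output or the new length check in `output()` raises:

```python
import torch

# Hypothetical usage sketch: one torch dtype per graph output, in output order.
interpreter = TRTInterpreter(
    module,       # assumed torch.fx.GraphModule
    input_specs,  # assumed List[InputTensorSpec]
    output_dtypes=[torch.float32, torch.int32],
)
```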