@@ -357,12 +357,14 @@ def wrapper(*args, **kwargs):
357357
358358
359359class ConversionContext (object ):
360- def __init__ (self , network , converters = CONVERTERS ):
360+
361+ def __init__ (self , network , converters = CONVERTERS , torch2trt_kwargs = None ):
361362 self .network = LayerNamingNetworkWrapper (self , network )
362363 self .lock = False
363364 self .method_args = None
364365 self .method_kwargs = None
365366 self .method_return = None
367+ self .torch2trt_kwargs = torch2trt_kwargs
366368 self .hooks = [
367369 ConversionHook (self , key , converter )
368370 for key , converter in converters .items ()
@@ -487,7 +489,12 @@ def torch2trt(module,
487489 int8_calib_dataset = None ,
488490 int8_calib_algorithm = DEFAULT_CALIBRATION_ALGORITHM ,
489491 int8_calib_batch_size = 1 ,
490- use_onnx = False ):
492+ use_onnx = False ,
493+ ** kwargs ):
494+
495+ # capture arguments to provide to context
496+ kwargs .update (locals ())
497+ kwargs .pop ('kwargs' )
491498
492499 inputs_in = inputs
493500
@@ -524,7 +531,7 @@ def torch2trt(module,
524531
525532 else :
526533 network = builder .create_network ()
527- with ConversionContext (network ) as ctx :
534+ with ConversionContext (network , torch2trt_kwargs = kwargs ) as ctx :
528535
529536 ctx .add_inputs (inputs , input_names )
530537
0 commit comments