Skip to content

Commit f2066bc

Browse files
authored
added passing of torch2trt_kwargs to conversion context (#482)
* added passing of torch2trt_kwargs to conversion context * added passing of torch2trt_kwargs to conversion context
1 parent 755462a commit f2066bc

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

torch2trt/torch2trt.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,14 @@ def wrapper(*args, **kwargs):
357357

358358

359359
class ConversionContext(object):
360-
def __init__(self, network, converters=CONVERTERS):
360+
361+
def __init__(self, network, converters=CONVERTERS, torch2trt_kwargs=None):
361362
self.network = LayerNamingNetworkWrapper(self, network)
362363
self.lock = False
363364
self.method_args = None
364365
self.method_kwargs = None
365366
self.method_return = None
367+
self.torch2trt_kwargs = torch2trt_kwargs
366368
self.hooks = [
367369
ConversionHook(self, key, converter)
368370
for key, converter in converters.items()
@@ -487,7 +489,12 @@ def torch2trt(module,
487489
int8_calib_dataset=None,
488490
int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
489491
int8_calib_batch_size=1,
490-
use_onnx=False):
492+
use_onnx=False,
493+
**kwargs):
494+
495+
# capture arguments to provide to context
496+
kwargs.update(locals())
497+
kwargs.pop('kwargs')
491498

492499
inputs_in = inputs
493500

@@ -524,7 +531,7 @@ def torch2trt(module,
524531

525532
else:
526533
network = builder.create_network()
527-
with ConversionContext(network) as ctx:
534+
with ConversionContext(network, torch2trt_kwargs=kwargs) as ctx:
528535

529536
ctx.add_inputs(inputs, input_names)
530537

0 commit comments

Comments
 (0)