diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 6a33a9ee..e07d37a7 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -522,18 +522,16 @@ def torch2trt(module, inputs = tuple(inputs) if not isinstance(inputs, tuple): inputs = (inputs,) - - # run once to get num outputs - outputs = module(*inputs) - if not isinstance(outputs, tuple) and not isinstance(outputs, list): - outputs = (outputs,) - if input_names is None: input_names = default_input_names(len(inputs)) - if output_names is None: - output_names = default_output_names(len(outputs)) - + if use_onnx: + # run once to get num outputs + outputs = module(*inputs) + if not isinstance(outputs, tuple) and not isinstance(outputs, list): + outputs = (outputs,) + if output_names is None: + output_names = default_output_names(len(outputs)) f = io.BytesIO() torch.onnx.export(module, inputs, f, input_names=input_names, output_names=output_names) @@ -553,6 +551,8 @@ def torch2trt(module, if not isinstance(outputs, tuple) and not isinstance(outputs, list): outputs = (outputs,) + if output_names is None: + output_names = default_output_names(len(outputs)) ctx.mark_outputs(outputs, output_names) builder.max_workspace_size = max_workspace_size @@ -575,7 +575,11 @@ def torch2trt(module, inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm ) + del inputs + del outputs + torch.cuda.empty_cache() engine = builder.build_cuda_engine(network) + torch.cuda.empty_cache() module_trt = TRTModule(engine, input_names, output_names)