From e5066128fd9d11f3b266507cc39b09820402b1e9 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 23 Oct 2025 11:39:33 -0700
Subject: [PATCH 1/6] feat: Change default export behavior to re-export

Signed-off-by: Dheeraj Peri
---
 py/torch_tensorrt/_compile.py               |  4 ++--
 tests/py/dynamo/models/test_export_serde.py | 20 ++++++++++----------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index f9df21856c..623bad7b9f 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -606,7 +606,7 @@ def save(
     inputs: Optional[Sequence[torch.Tensor]] = None,
     arg_inputs: Optional[Sequence[torch.Tensor]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
-    retrace: bool = False,
+    retrace: bool = True,
     pickle_protocol: int = 2,
     **kwargs: Any,
 ) -> None:
@@ -661,7 +661,7 @@ def save(
             "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
         )
     elif module_type == _ModuleType.ts:
-        if not all([output_format == f for f in ["exported_program", "aot_inductor"]]):
+        if not all(output_format == f for f in ["exported_program", "aot_inductor"]):
             raise ValueError(
                 "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported"
             )
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index d9c2ca3b0b..c499da160e 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -56,7 +56,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -111,7 +111,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -170,7 +170,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -232,7 +232,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir):
         msg="Model should be offloaded to CPU",
     )
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     # TODO: Enable this serialization issues are fixed
     # deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = deser_trt_module(input)
     outputs_trt = trt_module(input)
@@ -463,7 +463,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -525,7 +525,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -584,7 +584,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

From cc83065acd23229538bc9e72ca9f73176f17294e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Thu, 23 Oct 2025 13:39:10 -0700
Subject: [PATCH 2/6] chore: update test

Signed-off-by: Dheeraj Peri
---
 tests/py/dynamo/models/test_export_serde.py | 22 ++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index c499da160e..1a372bb4c1 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -56,7 +56,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -111,7 +111,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -170,7 +170,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -232,7 +232,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -279,7 +279,7 @@ def test_resnet18(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir):
         msg="Model should be offloaded to CPU",
     )
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     # TODO: Enable this serialization issues are fixed
     # deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = deser_trt_module(input)
     outputs_trt = trt_module(input)
@@ -463,7 +463,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -525,7 +525,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -584,7 +584,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

From 7be898f6ff3948af53d2f5d35e4033cd4053e101 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Fri, 24 Oct 2025 12:17:04 -0700
Subject: [PATCH 3/6] chore: fix tests

Signed-off-by: Dheeraj Peri
---
 .../dynamo/models/test_export_kwargs_serde.py | 10 ++++-----
 tests/py/dynamo/models/test_export_serde.py   | 22 +++++++++----------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/tests/py/dynamo/models/test_export_kwargs_serde.py b/tests/py/dynamo/models/test_export_kwargs_serde.py
index 70a0fde12f..dabbad3cc8 100644
--- a/tests/py/dynamo/models/test_export_kwargs_serde.py
+++ b/tests/py/dynamo/models/test_export_kwargs_serde.py
@@ -76,7 +76,7 @@ def forward(self, x, b=5, c=None, d=None):
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -138,7 +138,7 @@ def forward(self, x, b=5, c=None, d=None):
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -209,7 +209,7 @@ def forward(self, x, b=5, c=None, d=None):
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -299,7 +299,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
     )
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
@@ -389,7 +389,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
     )
     # Save the module
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, retrace=False)
     # Clean up model env
     torch._dynamo.reset()
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index 1a372bb4c1..c5b007e34b 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -56,7 +56,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -111,7 +111,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -170,7 +170,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     # Check Pyt and TRT exported program outputs
@@ -232,7 +232,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -279,7 +279,7 @@ def test_resnet18(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -331,7 +331,7 @@ def test_resnet18_cpu_offload(ir):
         msg="Model should be offloaded to CPU",
    )
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -380,7 +380,7 @@ def test_resnet18_dynamic(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     # TODO: Enable this serialization issues are fixed
     # deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -413,7 +413,7 @@ def test_resnet18_torch_exec_ops_serde(ir):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = deser_trt_module(input)
     outputs_trt = trt_module(input)
@@ -463,7 +463,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -525,7 +525,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
     model.cuda()
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)
@@ -584,7 +584,7 @@ def forward(self, x):
     exp_program = torchtrt.dynamo.trace(model, **compile_spec)
     trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
-    torchtrt.save(trt_module, trt_ep_path, arg_inputs=[input], retrace=False)
+    torchtrt.save(trt_module, trt_ep_path, retrace=False)
     deser_trt_module = torchtrt.load(trt_ep_path).module()
     outputs_pyt = model(input)

From 7d2e33fbb2e3238d6c363441b21287d2424ab9a2 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 29 Oct 2025 01:43:31 -0700
Subject: [PATCH 4/6] chore: fix some tests

Signed-off-by: Dheeraj Peri
---
 py/torch_tensorrt/dynamo/_refit.py                   |  8 +++++++-
 tests/py/dynamo/models/test_model_refit.py           | 10 ++++++----
 tests/py/dynamo/runtime/test_002_lazy_engine_init.py |  4 +++-
 3 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 9aae901f87..0cacb5653f 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -274,9 +274,15 @@ def refit_module_weights(
     else:
         for name, submodule in compiled_module.named_children():
             if not isinstance(
-                submodule, (PythonTorchTensorRTModule, TorchTensorRTModule)
+                submodule,
+                (
+                    PythonTorchTensorRTModule,
+                    TorchTensorRTModule,
+                    torch.nn.modules.module.Module,
+                ),
             ):
                 continue
+
             settings = submodule.settings
     assert settings is not None
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index 222221e089..2ec66e1367 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -540,8 +540,8 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
     min_block_size = 1
     use_python_runtime = False

-    exp_program = torch.export.export(model, tuple(inputs))
-    exp_program2 = torch.export.export(model2, tuple(inputs))
+    exp_program = torch.export.export(model, tuple(inputs), strict=False)
+    exp_program2 = torch.export.export(model2, tuple(inputs), strict=False)

     trt_gm = torchtrt.dynamo.compile(
         exp_program,
@@ -551,8 +551,10 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
         min_block_size=min_block_size,
         immutable_weights=False,
     )
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs)
+
     trt_gm = torch.export.load(trt_ep_path)
+
     new_trt_gm = refit_module_weights(
         compiled_module=trt_gm,
         new_weight_module=exp_program2,
@@ -906,7 +908,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
         min_block_size=min_block_size,
         immutable_weights=False,
     )
-    torchtrt.save(trt_gm, trt_ep_path)
+    torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
     new_trt_gm = refit_module_weights(
         compiled_module=trt_gm,
diff --git a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py
index ca82797090..539c11a303 100644
--- a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py
+++ b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py
@@ -314,7 +314,9 @@ def test_lazy_engine_init_cpp_serialization(self):
         trt_mod = torchtrt.compile(model, **compile_spec)

         with tempfile.TemporaryDirectory() as tmpdir:
-            torch_tensorrt.save(trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"))
+            torch_tensorrt.save(
+                trt_mod, os.path.join(tmpdir, "tmp_trt_mod.ep"), arg_inputs=(input,)
+            )
             new_trt_mod = torch.export.load(os.path.join(tmpdir, "tmp_trt_mod.ep"))
         loaded_trt_mod = new_trt_mod.module()

From c4d1e876a691f4a74324a5874fa998c2667c2a82 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Fri, 31 Oct 2025 14:10:48 -0700
Subject: [PATCH 5/6] chore: updates to refit

Signed-off-by: Dheeraj Peri
---
 py/torch_tensorrt/dynamo/_refit.py          | 56 +++++++++++++++----
 .../dynamo/lowering/passes/__init__.py      |  1 +
 tests/py/dynamo/models/test_model_refit.py  |  4 +-
 3 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 0cacb5653f..56cb7e1525 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -27,6 +27,7 @@
 )
 from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
 from torch_tensorrt.dynamo.lowering import (
+    clean_up_graph_after_modifications,
     get_decompositions,
     post_lowering,
     pre_export_lowering,
@@ -272,6 +273,9 @@ def refit_module_weights(
             compiled_submodules_map[name] = submodule
     else:
+        # Handle torch modules
+        compiled_submodules_map = {}
+        guard_fn_modules = []
         for name, submodule in compiled_module.named_children():
             if not isinstance(
                 submodule,
                 (
                     PythonTorchTensorRTModule,
                     TorchTensorRTModule,
                     torch.nn.modules.module.Module,
                 ),
             ):
                 continue
-            settings = submodule.settings
+            # When we re-export the graph module, torch.export._unlift.GuardsFn modules are being added as submodules.
+            if isinstance(submodule, torch.export._unlift.GuardsFn):
+                guard_fn_modules.append(name)
+                continue
+            # Obtain the settings
+
+            compiled_submodules = [
+                (name.replace("_engine", ""), engine)
+                for name, engine in submodule.__dict__.items()
+                if "engine" in name
+            ]
+
+            encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
+                SERIALIZED_METADATA_IDX
+            ]
+            assert (
+                encoded_metadata != ""
+            ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version"
+            settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
+
+            compiled_submodules_map[name] = submodule
+
+        # Delete the guard fn modules to avoid the guard fn modules being refitted
+        # First, remove nodes in the graph that reference the guard function modules
+        for node in list(compiled_module.graph.nodes):
+            if node.op == "call_module" and node.target in guard_fn_modules:
+                compiled_module.graph.erase_node(node)
+
+        # Now delete the submodules themselves
+        for guard_fn_module_name in guard_fn_modules:
+            # delattr(compiled_module, guard_fn_module_name)
+            compiled_module.delete_submodule(guard_fn_module_name)
+
+        # Clean up the graph
+        clean_up_graph_after_modifications(compiled_module)

     assert settings is not None
@@ -417,6 +455,7 @@ def refit_module_weights(
             )
         else:
             compiled_submodule = getattr(compiled_module, name)
+            weight_name_map = None
             if use_weight_map_cache:
                 try:
@@ -433,21 +472,16 @@ def refit_module_weights(
                     logger.warning(
                         "This engine does not have a weight map cache. Rebuilding the weight map"
                     )
-                if isinstance(compiled_submodule, PythonTorchTensorRTModule):
+
+                # Re-exporting the TRT compiled graph module and loading it back doesn't preserve the instance type and registers
+                # the compiled submodule as torch.nn.Module. So we use settings.use_python_runtime to determine the instance type.
+                if settings.use_python_runtime:
                     engine = compiled_submodule.engine
-                elif isinstance(compiled_submodule, TorchTensorRTModule):
+                else:
                     engine_info = compiled_submodule.engine.__getstate__()[0]
                     engine = get_engine_from_encoded_engine(
                         engine_info[ENGINE_IDX], runtime
                     )
-                elif isinstance(compiled_submodule, torch.fx.graph_module.GraphModule):
-                    # This is graph break resulted by unsupported ops
-                    compiled_submodule.load_state_dict(new_submodule.state_dict())
-                    continue
-                else:
-                    raise AssertionError(
-                        "The type of graph module is not supported for refitting."
-                    )
             except AttributeError:
                 raise AssertionError(
                     "The type of graph module is not supported for refitting or two compiled modules do not match."
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py
index c0e2803e60..c980224869 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -1,3 +1,4 @@
 from ._aten_lowering_pass import *
+from .pass_utils import clean_up_graph_after_modifications
 from .remove_sym_nodes import remove_sym_nodes
 from .repair_input_aliasing import repair_input_aliasing
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index 2ec66e1367..e6b7f6e2a4 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -551,8 +551,7 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
         min_block_size=min_block_size,
         immutable_weights=False,
     )
-    torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs)
-
+    torchtrt.save(trt_gm, trt_ep_path, arg_inputs=inputs, retrace=True)
     trt_gm = torch.export.load(trt_ep_path)

     new_trt_gm = refit_module_weights(
@@ -567,6 +566,7 @@ def test_refit_one_engine_inline_runtime_with_weightmap():
     expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(
         *inputs
     )
+
     for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
         assertions.assertTrue(
             torch.allclose(expected_output, refitted_output, 1e-2, 1e-2),

From f3c086681e80a2dbfcecd0f95143b255caa6a179 Mon Sep 17 00:00:00 2001
From: cehongwang
Date: Tue, 4 Nov 2025 23:44:40 +0000
Subject: [PATCH 6/6] Fixed the CI issue

---
 py/torch_tensorrt/dynamo/_refit.py | 69 ++++++++++++++++++++++--------
 1 file changed, 52 insertions(+), 17 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 56cb7e1525..2dcc75bcb7 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -95,6 +95,8 @@ def construct_refit_mapping_from_weight_name_map(
     engine_weight_map = {}
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
         # Add more constant folding converters here
+        trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
         if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
             # Batch Norm Layer
             params = {}
@@ -107,12 +109,12 @@ def construct_refit_mapping_from_weight_name_map(
             engine_weight_map[engine_weight_name] = eval(
                 engine_weight_name.split(" ")[-1].lower()
             )
+
         elif sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
-            trt_dtype = dtype._from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype._from(np_weight_type).to(torch.dtype)
+
             engine_weight_map[engine_weight_name] = state_dict[sd_weight_name].to(
                 to_torch_device(settings.device)
             )
@@ -277,13 +279,16 @@ def refit_module_weights(
         compiled_submodules_map = {}
         guard_fn_modules = []
         for name, submodule in compiled_module.named_children():
-            if not isinstance(
-                submodule,
-                (
-                    PythonTorchTensorRTModule,
-                    TorchTensorRTModule,
-                    torch.nn.modules.module.Module,
-                ),
+            if (
+                not isinstance(
+                    submodule,
+                    (
+                        PythonTorchTensorRTModule,
+                        TorchTensorRTModule,
+                        torch.nn.modules.module.Module,
+                    ),
+                )
+                or "_run_on_gpu" in name
             ):
                 continue
@@ -299,13 +304,21 @@ def refit_module_weights(
                 if "engine" in name
             ]

-            encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
-                SERIALIZED_METADATA_IDX
-            ]
-            assert (
-                encoded_metadata != ""
-            ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version"
-            settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
+            settings = None
+            try:
+                # If the gm is not inlined or transformed by retracing, the settings are stored in the submodule
+                settings = submodule.settings
+            except AttributeError:
+
+                encoded_metadata = [
+                    engine for name, engine in compiled_submodules if name == "engine"
+                ][0].__getstate__()[0][SERIALIZED_METADATA_IDX]
+                assert (
+                    encoded_metadata != ""
+                ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version"
+                settings = TorchTensorRTModule.decode_metadata(encoded_metadata)[
+                    "settings"
+                ]

             compiled_submodules_map[name] = submodule
@@ -455,12 +468,29 @@ def refit_module_weights(
             )
         else:
             compiled_submodule = getattr(compiled_module, name)
+            if "_run_on_acc" not in name:
+                compiled_submodule.load_state_dict(new_submodule.state_dict())
+                continue
             weight_name_map = None
             if use_weight_map_cache:
                 try:
                     weight_name_map = compiled_submodule.weight_name_map
                 except AttributeError:
+                    if isinstance(compiled_submodule, torch.nn.Module):
+                        # Torch retrace module
+                        assert (
+                            not settings.use_python_runtime
+                        ), "Refitting a torch retraced module is only supported with use_python_runtime=False"
+                        encoded_metadata = [
+                            engine
+                            for name, engine in compiled_submodules
+                            if name == "engine"
+                        ][0].__getstate__()[0][SERIALIZED_METADATA_IDX]
+                        weight_name_map = TorchTensorRTModule.decode_metadata(
+                            encoded_metadata
+                        )["weight_name_map"]
+
                     if not isinstance(
                         compiled_submodule, torch.fx.graph_module.GraphModule
                     ):
@@ -540,7 +570,12 @@ def refit_module_weights(
                 new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
                 refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
                 setattr(compiled_module, f"{name}_engine", refitted_engine)
-
+            elif isinstance(compiled_submodule, torch.nn.Module):
+                # Torch retrace module
+                new_engine_info = list(engine_info)
+                new_engine_info[ENGINE_IDX] = bytes(serialized_engine)
+                refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info))
+                compiled_submodule.engine = refitted_engine
         del engine
         gc.collect()
         torch.cuda.empty_cache()
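
Note on usage: across this series, the default of retrace in torch_tensorrt.save moves from False to True, so saving a compiled module now re-exports it by default. Callers who want the previous inline-engine behavior must pass retrace=False explicitly, and the re-export path generally needs example inputs, which is why the updated tests pass arg_inputs. The sketch below illustrates the two call patterns; the ResNet-18 model, input shape, and file path are illustrative placeholders and are not taken from the patches themselves.

import os
import tempfile

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

# Placeholder model and inputs, used only to show the save/load flow.
model = models.resnet18().eval().cuda()
inputs = [torch.randn((1, 3, 224, 224)).cuda()]

exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torchtrt.dynamo.compile(exp_program, arg_inputs=inputs)

ep_path = os.path.join(tempfile.gettempdir(), "trt_model.ep")

# New default (retrace=True): the compiled module is re-exported, so example
# inputs are required for the re-trace.
torchtrt.save(trt_gm, ep_path, arg_inputs=inputs)

# Previous behavior: keep the inline TensorRT engines without re-tracing.
torchtrt.save(trt_gm, ep_path, retrace=False)

# Either way, the saved program can be reloaded and executed.
reloaded = torchtrt.load(ep_path).module()
outputs = reloaded(*inputs)

With retrace=True, the saved ExportedProgram can also be reloaded via torch.export.load and passed to refit_module_weights, which is the path the later patches in this series exercise.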