@@ -89,6 +89,70 @@ def test_mapping():
8989 torch ._dynamo .reset ()
9090
9191
92+ @unittest .skipIf (
93+ not torch_trt .ENABLED_FEATURES .torch_tensorrt_runtime ,
94+ "TorchScript Frontend is not available" ,
95+ )
96+ @unittest .skipIf (
97+ not torch_trt .ENABLED_FEATURES .refit ,
98+ "Refit feature is not supported in Python 3.13 or higher" ,
99+ )
100+ @unittest .skipIf (
101+ not importlib .util .find_spec ("torchvision" ),
102+ "torchvision is not installed" ,
103+ )
104+ @pytest .mark .unit
105+ def test_conv_refit_with_weightmap ():
106+ class net (nn .Module ):
107+ def __init__ (self ):
108+ super ().__init__ ()
109+ self .conv = nn .Conv2d (3 , 3 , 1 )
110+
111+ def forward (self , x ):
112+ return self .conv (x )
113+
114+ model = net ().eval ().to ("cuda" )
115+ model2 = net ().eval ().to ("cuda" )
116+ inputs = [torch .randn ((1 , 3 , 224 , 224 )).to ("cuda" )]
117+ enabled_precisions = {torch .float }
118+ min_block_size = 1
119+ use_python_runtime = True
120+
121+ exp_program = torch .export .export (model , tuple (inputs ))
122+ exp_program2 = torch .export .export (model2 , tuple (inputs ))
123+
124+ trt_gm = torchtrt .dynamo .compile (
125+ exp_program ,
126+ tuple (inputs ),
127+ use_python_runtime = use_python_runtime ,
128+ enabled_precisions = enabled_precisions ,
129+ min_block_size = min_block_size ,
130+ immutable_weights = False ,
131+ )
132+
133+ new_trt_gm = refit_module_weights (
134+ compiled_module = trt_gm ,
135+ new_weight_module = exp_program2 ,
136+ arg_inputs = inputs ,
137+ use_weight_map_cache = True ,
138+ verify_output = True ,
139+ )
140+
141+ # Check the output
142+ model2 .to ("cuda" )
143+ expected_outputs , refitted_outputs = exp_program2 .module ()(* inputs ), new_trt_gm (
144+ * inputs
145+ )
146+ for expected_output , refitted_output in zip (expected_outputs , refitted_outputs ):
147+ assertions .assertTrue (
148+ torch .allclose (expected_output , refitted_output , 1e-2 , 1e-2 ),
149+ "Refit Result is not correct. Refit failed" ,
150+ )
151+ # Clean up model env
152+
153+ torch ._dynamo .reset ()
154+
155+
92156@unittest .skipIf (
93157 not torch_trt .ENABLED_FEATURES .torch_tensorrt_runtime ,
94158 "TorchScript Frontend is not available" ,
0 commit comments