 import collections.abc
 import copy
 import logging
-from typing import Any, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple

 import numpy as np
 import tensorrt as trt
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo._exporter import inline_torch_modules
-from torch_tensorrt.dynamo.conversion import CompilationSettings
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -108,38 +108,97 @@ def construct_refit_mapping(
     return weight_map


+def construct_refit_mapping_from_weight_name_map(
+    weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
+) -> dict[Any, Any]:
+    engine_weight_map = {}
+    for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
+        trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+        torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
+            # Batch Norm Layer
+            params = {}
+            for w in sd_weight_name:
+                params[w.split(".")[-1]] = state_dict[w]
+            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
+            shift = params["bias"] - params["running_mean"] * scale
+            # Pick `scale` for SCALE entries and `shift` for SHIFT entries
+            engine_weight_map[engine_weight_name] = eval(
+                engine_weight_name.split(" ")[-1].lower()
+            )
+
+        elif sd_weight_name not in state_dict:
+            # If the weight is not in the state_dict, leave it unchanged
+            continue
+        else:
+            engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]
+
+        engine_weight_map[engine_weight_name] = (
+            engine_weight_map[engine_weight_name]
+            .clone()
+            .reshape(-1)
+            .contiguous()
+            .to(torch_dtype),
+            trt_dtype,
+        )
+
+    return engine_weight_map
+
+
 def _refit_single_trt_engine_with_gm(
     new_gm: torch.fx.GraphModule,
     old_engine: trt.ICudaEngine,
-    input_list: Tuple[Any, ...],
+    input_list: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
+    weight_name_map: Optional[dict[str, List[str]]] = None,
 ) -> None:
     """
     Refit a TensorRT Engine in place
     """
-    # Get the refitting mapping
-    mapping = construct_refit_mapping(new_gm, input_list, settings)
+
     refitted = set()

-    trt_wt_location = trt.TensorLocation.HOST
     refitter = trt.Refitter(old_engine, TRT_LOGGER)
     weight_list = refitter.get_all_weights()

-    for layer_name in weight_list:
-        if layer_name not in mapping:
-            raise AssertionError(f"{layer_name} is not found in weight mapping")
-        # Use Numpy to create weights
-        weight, datatype = mapping[layer_name]
-        trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-        refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        refitted.add(layer_name)
+    if weight_name_map:
+        # Get the refitting mapping
+        trt_wt_location = trt.TensorLocation.DEVICE
+        mapping = construct_refit_mapping_from_weight_name_map(
+            weight_name_map, new_gm.state_dict()
+        )
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                logger.warning(f"{layer_name} is not found in weight mapping.")
+                continue
+            # Build the TensorRT weights directly from the torch tensor
+            weight, weight_dtype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(
+                weight_dtype, weight.data_ptr(), torch.numel(weight)
+            )
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+        assert (
+            len(refitter.get_missing_weights()) == 0
+        ), "Fast refitting failed due to incomplete mapping"

-    if len(refitted) != len(weight_list):
-        logger.warning("Not all weights have been refitted!!!")
+    else:
+        mapping = construct_refit_mapping(new_gm, input_list, settings)
+        trt_wt_location = trt.TensorLocation.HOST
+        for layer_name in weight_list:
+            if layer_name not in mapping:
+                raise AssertionError(f"{layer_name} is not found in weight mapping")
+            # Use Numpy to create weights
+            weight, datatype = mapping[layer_name]
+            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            refitted.add(layer_name)
+
+        if len(refitted) != len(weight_list):
+            logger.warning("Not all weights have been refitted!!!")

     if not refitter.refit_cuda_engine():
         logger.error("Error: failed to refit new weights.")
-        exit(0)
+        raise AssertionError("Refitting failed.")


 def refit_module_weights(
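Note on the batch-norm branch of the new helper: instead of copying one tensor per engine weight, it folds the four BatchNorm parameters into the SCALE and SHIFT tensors that the engine stores. Below is a minimal torch-only sketch of that fold, with illustrative tensor names that are not taken from a real weight_name_map:

import torch

# Hypothetical state_dict entries for a single BatchNorm layer (names are made up).
state_dict = {
    "bn1.weight": torch.ones(4),
    "bn1.bias": torch.zeros(4),
    "bn1.running_mean": torch.zeros(4),
    "bn1.running_var": torch.ones(4),
}
sd_weight_names = ["bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var"]

# Same fold as construct_refit_mapping_from_weight_name_map above:
# collapse the BatchNorm statistics into SCALE and SHIFT.
params = {w.split(".")[-1]: state_dict[w] for w in sd_weight_names}
scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
shift = params["bias"] - params["running_mean"] * scale

# As in the helper, the refitter receives flattened, contiguous tensors.
print(scale.reshape(-1).contiguous(), shift.reshape(-1).contiguous())

All other entries are a direct lookup of state_dict[sd_weight_name], cloned, flattened, and cast to the dtype recorded in the weight_name_map.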
@@ -148,6 +207,8 @@ def refit_module_weights(
     arg_inputs: Optional[Tuple[Any, ...]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
     verify_output: bool = False,
+    use_weight_map_cache: bool = True,
+    in_place: bool = False,
 ) -> torch.fx.GraphModule:
     """
     Refit a compiled graph module with an ExportedProgram. This performs weight updates in compiled_module without recompiling the engine.
@@ -170,7 +231,12 @@ def refit_module_weights(
     if len(list(compiled_module.named_children())) == 0:
         inline_module = True

-    compiled_module = copy.deepcopy(compiled_module)
+    if not in_place:
+        compiled_module = copy.deepcopy(compiled_module)
+    elif inline_module:
+        raise AssertionError(
+            "Exported programs do not support in-place modification. Please set in_place to False and use the returned graph module."
+        )

     # Get the settings and check that they are uniform
     settings: CompilationSettings = None
@@ -182,13 +248,14 @@ def refit_module_weights(
             for name, engine in compiled_module.__dict__.items()
             if "engine" in name
         ]
-        encoded_settings = compiled_submodules[0][1].__getstate__()[0][
+        # [('_run_on_acc_0', inline_module)]
+        encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
             SERIALIZED_METADATA_IDX
         ]
         assert (
-            encoded_settings != ""
-        ), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
-        settings = TorchTensorRTModule.decode_metadata(encoded_settings)
+            encoded_metadata != ""
+        ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old; please recompile with the latest version and make_refitable=True"
+        settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
         # Handle torch modules
         compiled_submodules_map = dict(compiled_submodules)
         for name, submodule in compiled_module.named_children():
@@ -287,6 +354,7 @@ def refit_module_weights(
         # Extract engine from the submodule
         try:
             if inline_module:
+                weight_name_map = None
                 compiled_submodule = compiled_submodules_map[name]
                 # If this is a torch module, load the old state_dict
                 if "_run_on_acc" not in name:
@@ -297,8 +365,33 @@ def refit_module_weights(
                     engine = get_engine_from_encoded_engine(
                         engine_info[ENGINE_IDX], runtime
                     )
+                    if use_weight_map_cache:
+                        encoded_metadata = compiled_submodule.__getstate__()[0][
+                            SERIALIZED_METADATA_IDX
+                        ]
+                        weight_name_map = TorchTensorRTModule.decode_metadata(
+                            encoded_metadata
+                        )["weight_name_map"]
+                        if not weight_name_map:
+                            use_weight_map_cache = False
+                            logger.warning(
+                                "This engine does not have a weight map cache. Rebuilding the weight map"
+                            )
             else:
                 compiled_submodule = getattr(compiled_module, name)
+                weight_name_map = None
+                if use_weight_map_cache:
+                    try:
+                        weight_name_map = compiled_submodule.weight_name_map
+                    except AttributeError:
+                        logger.warning(
+                            "The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
+                        )
+                    if not weight_name_map:
+                        use_weight_map_cache = False
+                        logger.warning(
+                            "This engine does not have a weight map cache. Rebuilding the weight map"
+                        )
             if isinstance(compiled_submodule, PythonTorchTensorRTModule):
                 engine = compiled_submodule.engine
             elif isinstance(compiled_submodule, TorchTensorRTModule):
@@ -335,13 +428,25 @@ def refit_module_weights(
                     to_torch_device(settings.device),
                     name,
                 )
-
-            _refit_single_trt_engine_with_gm(
-                new_gm=new_submodule,
-                old_engine=engine,
-                input_list=submodule_inputs,
-                settings=settings,
-            )
+            try:
+                _refit_single_trt_engine_with_gm(
+                    new_gm=new_submodule,
+                    old_engine=engine,
+                    input_list=submodule_inputs,
+                    settings=settings,
+                    weight_name_map=weight_name_map,
+                )
+            except AssertionError as e:
+                # If fast refit was attempted and failed, fall back to regular refit
+                logger.warning(e)
+                if use_weight_map_cache and weight_name_map:
+                    _refit_single_trt_engine_with_gm(
+                        new_gm=new_submodule,
+                        old_engine=engine,
+                        input_list=submodule_inputs,
+                        settings=settings,
+                        weight_name_map=None,
+                    )

         if isinstance(compiled_submodule, TorchTensorRTModule):
             serialized_engine = bytes(engine.serialize())
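Taken together, use_weight_map_cache and in_place change how a refit is driven from user code. The sketch below is illustrative only: the names and order of the first two positional arguments are inferred from the docstring rather than visible in this diff, MyModel is a hypothetical module, and the compile call is just one way to produce a refittable engine (the assertion messages above require make_refitable=True):

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical eager model
inputs = (torch.randn(1, 3, 224, 224, device="cuda"),)

# Build a refittable engine.
compiled = torch_tensorrt.compile(
    model, ir="dynamo", inputs=inputs, make_refitable=True
)

# ...update the model's weights (e.g. load a new checkpoint), then re-export...
exported = torch.export.export(model, inputs)

refitted = torch_tensorrt.dynamo.refit_module_weights(
    compiled,
    exported,
    arg_inputs=inputs,
    use_weight_map_cache=True,  # try the cached weight-name map first (fast path)
    in_place=False,             # deep-copy compiled_module instead of mutating it
)

If the fast path trips the "Fast refitting failed due to incomplete mapping" assertion, the except branch above retries with weight_name_map=None, i.e. the full construct_refit_mapping path, so leaving the cache flag at its default of True is safe.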