1515from torch_tensorrt .dynamo .utils import (
1616 get_cpu_memory_usage ,
1717 get_output_dtypes ,
18- release_memory ,
18+ release_host_and_device_memory ,
1919)
2020
2121logger = logging .getLogger (__name__ )
@@ -39,7 +39,7 @@ def infer_module_output_dtypes(
3939 """
4040 outputs = [node for node in module .graph .nodes if node .op == "output" ]
4141 outputs = outputs [0 ].args
42- return get_output_dtypes (outputs , truncate_double )
42+ return get_output_dtypes (outputs , truncate_double ) # type: ignore[no-any-return]
4343
4444
4545def interpret_module_to_result (
@@ -60,8 +60,9 @@ def interpret_module_to_result(
6060 settings: Compilation settings
6161 engine_cache: Engine cache instance
6262 Returns:
63- TRTInterpreterResult
63+ SerializedInterpreterResult
6464 """
65+
6566 output_dtypes = infer_module_output_dtypes (
6667 module , truncate_double = settings .truncate_double
6768 )
@@ -80,7 +81,7 @@ def interpret_module_to_result(
8081 for attr in dir (module ):
8182 if attr .startswith ("_frozen_param" ):
8283 delattr (module , attr )
83- release_memory ()
84+ release_host_and_device_memory ()
8485 logger .debug (
8586 f"CPU memory usage after clearing frozen parameters and building memory in conversion: { get_cpu_memory_usage ()} MB"
8687 )
@@ -92,6 +93,27 @@ def interpret_module_to_result(
9293 logger .debug (
9394 f"CPU memory usage after serializing engine: { get_cpu_memory_usage ()} MB"
9495 )
96+
97+ # Engine caching only for refittable engines
98+ if (
99+ not settings .immutable_weights
100+ and settings .cache_built_engines
101+ and engine_cache is not None
102+ ):
103+ hash_val = engine_cache .get_hash (module , inputs , settings )
104+ engine_cache .insert (
105+ hash_val ,
106+ (
107+ serialized_engine ,
108+ interpreter_result .input_names ,
109+ interpreter_result .output_names ,
110+ inputs ,
111+ settings ,
112+ interpreter_result .weight_name_map ,
113+ interpreter_result .requires_output_allocator ,
114+ ),
115+ )
116+
95117 serialized_interpreter_result = SerializedInterpreterResult (
96118 serialized_engine = serialized_engine ,
97119 input_names = interpreter_result .input_names ,
@@ -120,7 +142,7 @@ def convert_module(
120142 Returns:
121143 PythonTorchTensorRTModule or TorchTensorRTModule
122144 """
123- interpreter_result = interpret_module_to_result (
145+ serialized_interpreter_result = interpret_module_to_result (
124146 module , inputs , settings , engine_cache = engine_cache
125147 )
126148
@@ -139,11 +161,11 @@ def convert_module(
139161 )
140162
141163 return rt_cls (
142- serialized_engine = interpreter_result .serialized_engine ,
143- input_binding_names = list (interpreter_result .input_names ),
144- output_binding_names = list (interpreter_result .output_names ),
164+ serialized_engine = serialized_interpreter_result .serialized_engine ,
165+ input_binding_names = list (serialized_interpreter_result .input_names ),
166+ output_binding_names = list (serialized_interpreter_result .output_names ),
145167 name = name ,
146168 settings = settings ,
147- weight_name_map = interpreter_result .weight_name_map ,
148- requires_output_allocator = interpreter_result .requires_output_allocator ,
169+ weight_name_map = serialized_interpreter_result .weight_name_map ,
170+ requires_output_allocator = serialized_interpreter_result .requires_output_allocator ,
149171 )
0 commit comments