@@ -123,7 +123,6 @@ class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
123123
124124     def __init__(
125125         self,
126-             cuda_engine: trt.ICudaEngine = None,
127126         serialized_engine: Optional[bytes] = None,
128127         input_binding_names: Optional[List[str]] = None,
129128         output_binding_names: Optional[List[str]] = None,
@@ -183,19 +182,7 @@ def __init__(
183182         # Unused currently - to be used by Dynamic Shape support implementation
184183         self.memory_pool = None
185184
186-             if cuda_engine:
187-                 assert isinstance(
188-                     cuda_engine, trt.ICudaEngine
189-                 ), "Cuda engine must be a trt.ICudaEngine object"
190-                 self.engine = cuda_engine
191-             elif serialized_engine:
192-                 assert isinstance(
193-                     serialized_engine, bytes
194-                 ), "Serialized engine must be a bytes object"
195-                 self.engine = serialized_engine
196-             else:
197-                 raise ValueError("Serialized engine or cuda engine must be provided")
198-
185+             self.serialized_engine = serialized_engine
199186         self.input_names = (
200187             input_binding_names if input_binding_names is not None else []
201188         )
@@ -217,6 +204,7 @@ def __init__(
217204             else False
218205         )
219206         self.settings = settings
207+             self.engine = None
220208         self.weight_name_map = weight_name_map
221209         self.target_platform = Platform.current_platform()
222210         self.runtime_states = TorchTRTRuntimeStates(
@@ -231,7 +219,7 @@ def __init__(
231219         self.output_allocator: Optional[DynamicOutputAllocator] = None
232220         self.use_output_allocator_outputs = False
233221
234-             if self.engine and not self.settings.lazy_engine_init:
222+             if self.serialized_engine is not None and not self.settings.lazy_engine_init:
235223             self.setup_engine()
236224
237225     def get_streamable_device_memory_budget(self) -> Any:
@@ -272,22 +260,13 @@ def set_default_device_memory_budget(self) -> int:
272260         return self._set_device_memory_budget(budget_bytes)
273261
274262     def setup_engine(self) -> None:
275-
276-             if isinstance(self.engine, trt.ICudaEngine):
277-                 pass
278-             elif isinstance(self.engine, bytes):
279-                 runtime = trt.Runtime(TRT_LOGGER)
280-                 self.engine = runtime.deserialize_cuda_engine(self.engine)
281-             else:
282-                 raise ValueError(
283-                     "Expected engine as trt.ICudaEngine or serialized engine as bytes"
284-                 )
285-
286263         assert (
287264             self.target_platform == Platform.current_platform()
288265         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
289266
290267         self.initialized = True
268+             runtime = trt.Runtime(TRT_LOGGER)
269+             self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
291270         if self.settings.enable_weight_streaming:
292271             self.set_default_device_memory_budget()
293272         self.context = self.engine.create_execution_context()
@@ -323,7 +302,7 @@ def _check_initialized(self) -> None:
323302             raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
324303
325304     def _on_state_dict(self, state_dict: Dict[str, Any], prefix: str, _: Any) -> None:
326-             state_dict[prefix + "engine"] = self.engine
305+             state_dict[prefix + "engine"] = self.serialized_engine
327306         state_dict[prefix + "input_names"] = self.input_names
328307         state_dict[prefix + "output_names"] = self.output_names
329308         state_dict[prefix + "platform"] = self.target_platform
@@ -338,7 +317,7 @@ def _load_from_state_dict(
338317         unexpected_keys: Any,
339318         error_msgs: Any,
340319     ) -> None:
341-             self.engine = state_dict[prefix + "engine"]
320+             self.serialized_engine = state_dict[prefix + "engine"]
342321         self.input_names = state_dict[prefix + "input_names"]
343322         self.output_names = state_dict[prefix + "output_names"]
344323         self.target_platform = state_dict[prefix + "platform"]
0 commit comments