 
 
 class TorchTRTRuntimeStates:
-    def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool):
+    def __init__(self, new_cudagraphs: bool):
         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
         self.old_cudagraphs = new_cudagraphs
         # Indicates whether pre-allocated output was enabled in the previous execute_engine
-        self.old_pre_allocated_outputs = new_pre_allocated_output
+        self.old_pre_allocated_outputs = False
+        # Indicates whether context has changed
+        self.context_changed = False
 
     def set_runtime_states(
         self,
         new_cudagraphs: bool,
         new_pre_allocated_output: bool,
         shape_changed: bool,
-    ) -> Tuple[bool, bool]:
-        # Evaluates whether certain conditions are met to enable CUDA Graph recording or to reuse pre-allocated outputs
+    ) -> Tuple[bool, bool, bool]:
+        # Evaluates whether certain conditions are met to enable CUDA Graph recording or to use pre-allocated outputs
         # based on the current and previous states, as well as input shape has changed
         need_cudagraphs_record = False
         can_use_pre_allocated_outputs = False
-
-        # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
-        if new_cudagraphs and (not self.old_cudagraphs or shape_changed):
+        need_cudagraphs_reset = False
+
+        # CUDA Graph recording is needed if CUDA graphs is enabled and:
+        # - CUDA graphs were previously disabled
+        # - or the shape has changed
+        # - or the execution context has changed (e.g., weight streaming)
+        if new_cudagraphs and (
+            not self.old_cudagraphs or shape_changed or self.context_changed
+        ):
             need_cudagraphs_record = True
 
         # Pre-allocated output can be used when previous and current state are true without shape change
@@ -53,10 +61,19 @@ def set_runtime_states(
         ):
             can_use_pre_allocated_outputs = True
 
+        if not new_cudagraphs or shape_changed or self.context_changed:
+            need_cudagraphs_reset = True
+
         self.old_cudagraphs = new_cudagraphs
         self.old_pre_allocated_outputs = new_pre_allocated_output
+        # reset flag
+        self.context_changed = False
 
-        return need_cudagraphs_record, can_use_pre_allocated_outputs
+        return (
+            need_cudagraphs_record,
+            can_use_pre_allocated_outputs,
+            need_cudagraphs_reset,
+        )
 
 
 class PythonTorchTensorRTModule(Module):  # type: ignore[misc]
@@ -145,7 +162,7 @@ def __init__(
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
         self.runtime_states = TorchTRTRuntimeStates(
-            torch_tensorrt.runtime.get_cudagraphs_mode(), False
+            torch_tensorrt.runtime.get_cudagraphs_mode()
         )
         self.pre_allocated_outputs: List[torch.Tensor] = []
         self.use_pre_allocated_outputs = False
@@ -168,6 +185,7 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
         del self.context
         budget_bytes = self._set_device_memory_budget(budget_bytes)
         self.context = self.engine.create_execution_context()
+        self.runtime_states.context_changed = True
         return budget_bytes
 
     def _set_device_memory_budget(self, budget_bytes: int) -> int:
@@ -200,7 +218,6 @@ def setup_engine(self) -> None:
         if self.settings.enable_weight_streaming:
             self.set_default_device_memory_budget()
         self.context = self.engine.create_execution_context()
-
         assert self.engine.num_io_tensors == (
             len(self.input_names) + len(self.output_names)
         )
@@ -356,22 +373,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
 
         cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
         shape_changed = self.validate_input_shapes(inputs)
-        need_cudagraphs_record, can_use_pre_allocated_outputs = (
-            self.runtime_states.set_runtime_states(
-                cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
-            )
+        (
+            need_cudagraphs_record,
+            can_use_pre_allocated_outputs,
+            need_cudagraphs_reset,
+        ) = self.runtime_states.set_runtime_states(
+            cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
         )
 
+        if need_cudagraphs_reset and self.cudagraph:
+            self.cudagraph.reset()
+            self.cudagraph = None
+
         if need_cudagraphs_record:
-            if self.cudagraph:
-                self.cudagraph.reset()
             self._input_buffers = [None] * len(self.input_names)
             self._output_buffers = [None] * len(self.output_names)
 
-        if not cudagraphs_enabled and self.cudagraph:
-            self.cudagraph.reset()
-            self.cudagraph = None
-
         # If in safe mode, check at each iteration for whether a switch is required
         if (
            torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
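For reviewers, a minimal usage sketch of the revised state machine follows. It is not part of the diff: the import path is an assumption based on the file this change touches, and the calls simply mirror the signatures shown above, walking through the cases where the new `need_cudagraphs_reset` flag is raised.

```python
# Illustrative only: exercises the three-flag state machine introduced above.
# The import path below is an assumption based on the module this diff edits.
from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (
    TorchTRTRuntimeStates,
)

# CUDA graphs enabled at construction; pre-allocated outputs start disabled.
states = TorchTRTRuntimeStates(new_cudagraphs=True)

# Steady state: CUDA graphs stay enabled and nothing changed, so neither a
# re-record nor a reset is requested.
record, use_prealloc, reset = states.set_runtime_states(
    new_cudagraphs=True, new_pre_allocated_output=False, shape_changed=False
)
assert (record, use_prealloc, reset) == (False, False, False)

# Simulate what set_device_memory_budget() now does after rebuilding the
# execution context (e.g., when the weight-streaming budget changes).
states.context_changed = True

# The context change forces both a reset of the captured graph and a fresh
# capture, even though the input shape is unchanged.
record, use_prealloc, reset = states.set_runtime_states(
    new_cudagraphs=True, new_pre_allocated_output=False, shape_changed=False
)
assert (record, use_prealloc, reset) == (True, False, True)

# Disabling CUDA graphs only asks the caller to tear down the captured graph.
record, use_prealloc, reset = states.set_runtime_states(
    new_cudagraphs=False, new_pre_allocated_output=False, shape_changed=False
)
assert (record, use_prealloc, reset) == (False, False, True)
```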