@@ -708,6 +708,35 @@ def _sample(self, probs: torch.Tensor, do_sample: bool) -> None:
         tokens = next_tokens.size(1)  # Get seq_len dimension
         self.output_ids[:, :tokens].copy_(next_tokens)
 
+    def close(self):
+        self.cache.close()
+        self.requests_in_batch.clear()
+
+        if self._graphs is not None:
+            self._graphs.clear()
+
+        del self.input_ids
+        del self.position_ids
+        del self.cumulative_seqlens_q
+        del self.logits_indices
+        del self.output_ids
+
+        self.cumulative_seqlens_k.clear()
+
+        if self.attention_mask is not None:
+            self.attention_mask.clear()
+            self.attention_mask = None
+
+        self.write_index_storage.clear()
+        self.read_index_storage.clear()
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            import gc
+
+            gc.collect()
+            torch.cuda.empty_cache()
+
 
 # Manager Class (User Interface)
 @attach_tracer()
@@ -826,11 +855,9 @@ def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
         if block:
             self.join(stop_trigger_time, timeout)
 
-        torch.cuda.synchronize()
-        import gc
-
-        gc.collect()
-        torch.cuda.empty_cache()
+        if self.batch_processor is not None:
+            self.batch_processor.close()
+            self.batch_processor = None  # NOTE: this is enough to clear memory after stop, still calling `close()` because it calls torch cache intrinsics
 
     def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
         """Wait for the background thread to finish.
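As a side note on the cleanup pattern in the diff above (not part of the PR): dropping the last Python reference to a CUDA tensor only returns its memory to PyTorch's caching allocator, while torch.cuda.empty_cache() is what hands the reserved blocks back to the driver. A minimal, self-contained sketch of that behavior, using only standard torch/gc calls and a hypothetical throwaway buffer:

import gc

import torch

if torch.cuda.is_available():
    buf = torch.empty(1024, 1024, device="cuda")   # ~4 MB allocated on the GPU (illustrative only)
    print("reserved before:", torch.cuda.memory_reserved())

    del buf                    # drop the last reference; memory returns to the caching allocator
    gc.collect()               # collect any reference cycles that might still pin tensors
    torch.cuda.synchronize()   # make sure no pending kernels still use the freed storage
    torch.cuda.empty_cache()   # hand the cached blocks back to the CUDA driver

    print("reserved after:", torch.cuda.memory_reserved())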