Commit 24f3b7e

fix(ci): effectively clear cache
1 parent 5b03675

2 files changed (+40, -5)

src/transformers/generation/continuous_batching/cache.py

Lines changed: 8 additions & 0 deletions
@@ -335,6 +335,14 @@ def update(
         # Return the new KV values
         return key_states_with_cache, value_states_with_cache

+    @traced
+    def close(self):
+        self.key_cache.clear()
+        self.value_cache.clear()
+
+        torch._dynamo.reset()
+        torch._dynamo.reset_code_caches()
+

 # TODO: rework computation with the groups and their sizes
 class PagedAttentionMemoryHandler:
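
Why the dynamo resets in addition to clearing the cache lists? Compiled artifacts can keep tensors reachable even after the Python-side lists are emptied, so clearing them alone may not release memory. A minimal sketch of that pattern, not from this repo (the toy make_step helper is purely illustrative):

    import torch

    def make_step():
        cache = [torch.zeros(1024, 1024)]  # stand-in for key/value cache tensors

        @torch.compile
        def step(x):
            return x @ cache[0]  # the compiled graph closes over `cache`

        return cache, step

    cache, step = make_step()
    step(torch.randn(1024, 1024))      # compile and run once

    cache.clear()                      # Python-level clear alone may not be enough
    torch._dynamo.reset()              # drop compiled graphs and guards
    torch._dynamo.reset_code_caches()  # drop per-code-object caches as well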

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 32 additions & 5 deletions
@@ -708,6 +708,35 @@ def _sample(self, probs: torch.Tensor, do_sample: bool) -> None:
         tokens = next_tokens.size(1)  # Get seq_len dimension
         self.output_ids[:, :tokens].copy_(next_tokens)

+    def close(self):
+        self.cache.close()
+        self.requests_in_batch.clear()
+
+        if self._graphs is not None:
+            self._graphs.clear()
+
+        del self.input_ids
+        del self.position_ids
+        del self.cumulative_seqlens_q
+        del self.logits_indices
+        del self.output_ids
+
+        self.cumulative_seqlens_k.clear()
+
+        if self.attention_mask is not None:
+            self.attention_mask.clear()
+            self.attention_mask = None
+
+        self.write_index_storage.clear()
+        self.read_index_storage.clear()
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            import gc
+
+            gc.collect()
+            torch.cuda.empty_cache()
+

 # Manager Class (User Interface)
 @attach_tracer()
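
The new close() follows a deliberate teardown order: drop Python references first, then synchronize so no in-flight kernel still uses the memory, run a GC pass, and finally return cached blocks to the CUDA allocator. A minimal sketch of that ordering (the release_gpu_buffers helper is illustrative, not part of the repo):

    import gc

    import torch

    def release_gpu_buffers(buffers: dict) -> None:
        buffers.clear()               # drop the last Python references
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # ensure no in-flight kernel still uses them
            gc.collect()              # break reference cycles so tensors are freed
            torch.cuda.empty_cache()  # hand cached blocks back to the allocator

    release_gpu_buffers({"kv": torch.zeros(8)})  # the guard makes this safe on CPU-only hosts
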
@@ -826,11 +855,9 @@ def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
         if block:
             self.join(stop_trigger_time, timeout)

-        torch.cuda.synchronize()
-        import gc
-
-        gc.collect()
-        torch.cuda.empty_cache()
+        if self.batch_processor is not None:
+            self.batch_processor.close()
+        self.batch_processor = None  # NOTE: this is enough to clear memory after stop, still calling `close()` because it calls torch cache intrinsics

     def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
         """Wait for the background thread to finish.
