Commit 1f2620b

fix(ci): reduce memory safety margin

1 parent 24f3b7e · commit 1f2620b

3 files changed: +2 −43 lines

benchmark_v2/framework/benchmark_runner.py

Lines changed: 0 additions & 2 deletions
@@ -117,8 +117,6 @@ def flush_memory():
     # Clear CUDA cache
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
         torch.cuda.synchronize()
     gc.collect()
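Editor's note: `torch.cuda.reset_max_memory_allocated()` is a deprecated alias for `torch.cuda.reset_peak_memory_stats()`, so the removed pair was redundant as well as unnecessary for freeing cache. One consequence of dropping the reset is that `torch.cuda.max_memory_allocated()` now reports a high-water mark spanning every run since process start. A minimal sketch of how a benchmark could still capture a per-run peak by resetting explicitly; `run_once` is a hypothetical stand-in for one benchmark iteration, not part of this commit:

```python
import torch


def measure_peak_bytes(run_once) -> int:
    # `run_once` is a hypothetical zero-argument callable that executes a
    # single benchmark iteration; it is not part of the diff above.
    torch.cuda.reset_peak_memory_stats()      # start the peak counter from "now"
    run_once()
    torch.cuda.synchronize()                  # wait for all queued kernels to finish
    return torch.cuda.max_memory_allocated()  # bytes at the high-water mark
```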

src/transformers/generation/continuous_batching/cache.py

Lines changed: 1 addition & 9 deletions
@@ -189,7 +189,7 @@ def __init__(
         num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
             num_blocks=getattr(generation_config, "num_blocks", None),
             max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
-            max_memory_percent=getattr(generation_config, "max_memory", 0.9),
+            max_memory_percent=getattr(generation_config, "max_memory", 0.8),
             cache_dtype=self.dtype,
         )
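The changed default lowers the share of device memory the paged KV cache may claim from 90% to 80%, leaving more headroom for activations, CUDA graphs, and allocator fragmentation. A simplified sketch of how such a fraction typically translates into a block count; the real computation lives in `PagedAttentionMemoryHandler.infer_num_blocks_and_max_batch_tokens`, and `bytes_per_block` here is a hypothetical input:

```python
import torch


def infer_num_blocks(max_memory_percent: float, bytes_per_block: int) -> int:
    # Simplified sketch, not the library's actual formula: take the allowed
    # fraction of currently free device memory and divide it into
    # fixed-size KV-cache blocks.
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    budget = int(free_bytes * max_memory_percent)
    return budget // bytes_per_block
```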

@@ -335,14 +335,6 @@ def update(
         # Return the new KV values
         return key_states_with_cache, value_states_with_cache

-    @traced
-    def close(self):
-        self.key_cache.clear()
-        self.value_cache.clear()
-
-        torch._dynamo.reset()
-        torch._dynamo.reset_code_caches()
-

 # TODO: rework computation with the groups and their sizes
 class PagedAttentionMemoryHandler:
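The deleted `close()` emptied the per-layer key/value lists and reset TorchDynamo's compilation caches. Dropping the cache object achieves the first part on its own: once the last Python reference to a CUDA tensor goes away, its storage returns to PyTorch's caching allocator with no explicit cleanup. A minimal sketch (requires a CUDA device); the `torch._dynamo.reset()` calls simply have no counterpart after this commit:

```python
import torch

cache = [torch.empty(1024, 1024, device="cuda") for _ in range(8)]
before = torch.cuda.memory_allocated()

cache = None  # drop the only reference; CPython frees the tensors immediately

after = torch.cuda.memory_allocated()
assert after < before  # the blocks are back in the allocator's free pool
```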

src/transformers/generation/continuous_batching/continuous_api.py

Lines changed: 1 addition & 32 deletions
@@ -708,35 +708,6 @@ def _sample(self, probs: torch.Tensor, do_sample: bool) -> None:
         tokens = next_tokens.size(1)  # Get seq_len dimension
         self.output_ids[:, :tokens].copy_(next_tokens)

-    def close(self):
-        self.cache.close()
-        self.requests_in_batch.clear()
-
-        if self._graphs is not None:
-            self._graphs.clear()
-
-        del self.input_ids
-        del self.position_ids
-        del self.cumulative_seqlens_q
-        del self.logits_indices
-        del self.output_ids
-
-        self.cumulative_seqlens_k.clear()
-
-        if self.attention_mask is not None:
-            self.attention_mask.clear()
-            self.attention_mask = None
-
-        self.write_index_storage.clear()
-        self.read_index_storage.clear()
-
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
-            import gc
-
-            gc.collect()
-            torch.cuda.empty_cache()
-

 # Manager Class (User Interface)
 @attach_tracer()
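Most of this deleted method released memory that reference counting reclaims anyway once `stop()` drops the processor (see the next hunk). Note also that `del self.attr` only removes one reference; it frees nothing while another reference is alive. A small sketch with a hypothetical `Holder` class (requires a CUDA device):

```python
import torch


class Holder:
    def __init__(self):
        self.buf = torch.empty(1 << 20, device="cuda")  # ~4 MB of float32


h = Holder()
alias = h.buf  # a second reference keeps the storage alive
del h.buf      # removes the attribute, frees nothing yet
assert torch.cuda.memory_allocated() > 0

alias = None   # last reference gone; storage returns to the allocator
assert torch.cuda.memory_allocated() == 0
```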
@@ -855,9 +826,7 @@ def stop(self, block: bool = True, timeout: Optional[float] = None) -> None:
         if block:
             self.join(stop_trigger_time, timeout)

-        if self.batch_processor is not None:
-            self.batch_processor.close()
-            self.batch_processor = None  # NOTE: this is enough to clear memory after stop, still calling `close()` because it calls torch cache intrinsics
+        self.batch_processor = None

     def join(self, stop_trigger_time: float, timeout: Optional[float] = None) -> None:
         """Wait for the background thread to finish.
