
Commit 309180f

[BLT] Fix cache usage (#42188)
* fix
* properly
* fix tests
1 parent 8976ceb commit 309180f

4 files changed, +32 −77 lines changed


src/transformers/models/blt/modeling_blt.py

Lines changed: 13 additions & 26 deletions
@@ -28,7 +28,7 @@
 import torch.nn.functional as F
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
@@ -321,7 +321,6 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
-        use_cache: bool = False,
         past_key_values=None,
         cache_position=None,
         **kwargs,
@@ -393,9 +392,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cross_attention_states: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -404,27 +401,13 @@ def forward(
         query_states = self.q_proj(query_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-        if cross_attention_states is not None:
-            cross_attention_states = self.k_norm(cross_attention_states)
-            key_states = self.k_proj(cross_attention_states)
-            value_states = self.v_proj(cross_attention_states)
-            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            if past_key_values is not None:
-                key_states, value_states = past_key_values.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
-        elif cache_position[0] != 0:
-            key_states, value_states = (
-                past_key_values.layers[self.layer_idx].keys,
-                past_key_values.layers[self.layer_idx].values,
-            )
-        else:
-            raise ValueError(
-                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
-            )
-        attention_interface: Callable = eager_attention_forward
+        cross_attention_states = self.k_norm(cross_attention_states)
+        key_states = self.k_proj(cross_attention_states)
+        value_states = self.v_proj(cross_attention_states)
+        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
@@ -1089,6 +1072,9 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
         # Extract input embeddings as early as possible
         if inputs_embeds is not None:
             encoder_embeds = inputs_embeds
@@ -1137,7 +1123,7 @@ def forward(
             input_embeds=encoder_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             position_ids=position_ids,
         )
 
@@ -1157,6 +1143,7 @@ def forward(
             encoder_attention_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             **kwargs,
         )
         encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -1192,7 +1179,7 @@ def forward(
             patch_embeds=global_hidden_states,
             attention_mask=causal_mask,
             position_ids=position_ids,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
             cache_position=cache_position,
             encoder_attention_mask=cross_attn_mask_dec,
             **kwargs,
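For orientation, the new cache handling in this file follows one pattern: lazily build a single EncoderDecoderCache that wraps two DynamicCache objects, then pass each sub-module call only the slice it owns. Below is a minimal sketch of that pattern, assuming a transformers version where DynamicCache accepts a config argument (as the added code does); the helper name make_blt_caches is illustrative and not part of the library.

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

def make_blt_caches(config, use_cache=True, past_key_values=None):
    # Lazily create the combined cache, mirroring the new branch in forward().
    if use_cache and past_key_values is None:
        past_key_values = EncoderDecoderCache(DynamicCache(config=config), DynamicCache(config=config))
    # Two of the sub-module calls in the diff receive the self-attention slice,
    # the third receives the cross-attention slice; a None cache passes through.
    self_attention_cache = past_key_values.self_attention_cache if past_key_values is not None else None
    cross_attention_cache = past_key_values.cross_attention_cache if past_key_values is not None else None
    return past_key_values, self_attention_cache, cross_attention_cache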

src/transformers/models/blt/modular_blt.py

Lines changed: 13 additions & 46 deletions
@@ -22,7 +22,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_rope_utils import dynamic_rope_update
@@ -299,27 +299,6 @@ def __init__(self, config, layer_idx: int):
 class BltSelfAttention(MllamaTextSelfAttention):
     def __init__(self, config: BltConfig, layer_idx: int):
         super().__init__(config, layer_idx)
-        self.is_causal = True
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        position_embeddings: torch.Tensor,
-        use_cache: bool = False,
-        past_key_values=None,
-        cache_position=None,
-        **kwargs,
-    ):
-        return super().forward(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_embeddings=position_embeddings,
-            use_cache=use_cache,
-            past_key_values=past_key_values,
-            cache_position=cache_position,
-            **kwargs,
-        )
 
 
 class BltCrossAttention(MllamaTextCrossAttention):
@@ -335,37 +314,21 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         cross_attention_states: Optional[torch.Tensor] = None,
-        past_key_values: Optional[Cache] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ):
         bsz, q_len, _ = hidden_states.size()
         query_states = self.q_norm(hidden_states)
         query_states = self.q_proj(query_states)
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-        if cross_attention_states is not None:
-            cross_attention_states = self.k_norm(cross_attention_states)
-            key_states = self.k_proj(cross_attention_states)
-            value_states = self.v_proj(cross_attention_states)
-            key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-            if past_key_values is not None:
-                key_states, value_states = past_key_values.update(
-                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
-                )
-        elif cache_position[0] != 0:
-            key_states, value_states = (
-                past_key_values.layers[self.layer_idx].keys,
-                past_key_values.layers[self.layer_idx].values,
-            )
-        else:
-            raise ValueError(
-                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
-            )
-        attention_interface: Callable = eager_attention_forward
+        cross_attention_states = self.k_norm(cross_attention_states)
+        key_states = self.k_proj(cross_attention_states)
+        value_states = self.v_proj(cross_attention_states)
+        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
@@ -828,6 +791,9 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
+        if use_cache and past_key_values is None:
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+
         # Extract input embeddings as early as possible
         if inputs_embeds is not None:
            encoder_embeds = inputs_embeds
@@ -876,7 +842,7 @@ def forward(
             input_embeds=encoder_embeds,
             attention_mask=attention_mask,
             cache_position=cache_position,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             position_ids=position_ids,
         )
 
@@ -896,6 +862,7 @@ def forward(
             encoder_attention_mask=cross_attn_mask_enc,
             num_patches=patch_lengths.shape[1],
             patch_ids=patch_ids,
+            past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
             **kwargs,
         )
         encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
@@ -931,7 +898,7 @@ def forward(
             patch_embeds=global_hidden_states,
             attention_mask=causal_mask,
             position_ids=position_ids,
-            past_key_values=past_key_values,
+            past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
             cache_position=cache_position,
             encoder_attention_mask=cross_attn_mask_dec,
             **kwargs,
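The branch removed from BltCrossAttention.forward leaned on the generic cache interface that the self-attention layers still use: past_key_values.update(key_states, value_states, layer_idx, cache_kwargs) appends the new states and returns the full cached tensors for that layer. A short, self-contained sketch of that contract, assuming a recent transformers version; the tensor shapes are illustrative only.

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# (batch, num_key_value_heads, seq_len, head_dim)
k = torch.randn(1, 4, 3, 8)
v = torch.randn(1, 4, 3, 8)
keys, values = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(3)})
print(keys.shape)  # torch.Size([1, 4, 3, 8])

# A later decoding step adds one token and extends the same layer's cache.
keys, values = cache.update(
    torch.randn(1, 4, 1, 8),
    torch.randn(1, 4, 1, 8),
    layer_idx=0,
    cache_kwargs={"cache_position": torch.tensor([3])},
)
print(keys.shape)  # torch.Size([1, 4, 4, 8])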

src/transformers/models/mllama/modeling_mllama.py

Lines changed: 0 additions & 2 deletions
@@ -534,10 +534,8 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
-        use_cache: bool = False,
         past_key_values=None,
         cache_position=None,
-        position_ids=None,
         **kwargs,
     ):
         bsz, q_len, _ = hidden_states.size()
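This mllama change only trims parameters that the layer did not use by name; callers that still pass use_cache or position_ids are unaffected because the signature keeps **kwargs. A tiny, purely illustrative sketch of that Python behaviour (not the Mllama code itself):

def forward(hidden_states, attention_mask, position_embeddings, past_key_values=None, cache_position=None, **kwargs):
    # Keyword arguments that are no longer declared explicitly land in **kwargs.
    return sorted(kwargs)

print(forward("h", "mask", "rope", use_cache=True, position_ids=[0, 1, 2]))
# ['position_ids', 'use_cache']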

tests/models/blt/test_modeling_blt.py

Lines changed: 6 additions & 3 deletions
@@ -224,12 +224,15 @@ def test_eager_matches_sdpa_inference(
 
 @require_torch_accelerator
 class BltIntegrationTest(unittest.TestCase):
+    def setup(self):
+        cleanup(torch_device, gc_collect=True)
+
     def tearDown(self):
         # TODO (joao): automatic compilation, i.e. compilation when `cache_implementation="static"` is used, leaves
         # some memory allocated in the cache, which means some object is not being released properly. This causes some
         # unoptimal memory usage, e.g. after certain tests a 7B model in FP16 no longer fits in a 24GB GPU.
         # Investigate the root cause.
-        cleanup(torch_device, gc_collect=False)
+        cleanup(torch_device, gc_collect=True)
 
     @slow
     @require_read_token
@@ -339,7 +342,7 @@ def test_model_logits(self):
     def test_model_bf16(self):
         """Test Blt model with bfloat16 precision."""
         NUM_TOKENS_TO_GENERATE = 200
-        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m"
 
         prompt = "my name is"
 
@@ -472,7 +475,7 @@ def test_model_eager(self):
     def test_model_bf16_static_cache(self):
         """Test Blt model with bfloat16 precision and static cache."""
         NUM_TOKENS_TO_GENERATE = 200
-        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan math club and the michigan computer s"
+        EXPECTED_TEXT = "my name is alex and i am a student at the university of michigan in the college of arts and sciences. i am a senior majoring in computer science and minoring in mathematics. i am also a member of the michigan m"
 
         prompt = "my name is"
 
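The integration tests above exercise the fix end to end through generate(), and the updated EXPECTED_TEXT values reflect the corrected cache behaviour. A hedged sketch of an equivalent manual run follows; the checkpoint id is a placeholder, and it assumes the BLT checkpoint is registered with the Auto classes and ships a tokenizer, as the tests rely on.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "path/to/blt-checkpoint"  # placeholder, not the id used in the tests
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# Same prompt and length as test_model_bf16; with use_cache=True the model now
# builds an EncoderDecoderCache internally instead of sharing one cache object.
inputs = tokenizer("my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False, use_cache=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))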
