Commit 33c60a5

[T5Gemma] Fix cross attention cache (#41890)
* fix
* add test
* style
* added comment
1 parent fa22b56 commit 33c60a5

File tree

3 files changed: +22 -2 lines changed

src/transformers/models/t5gemma/modeling_t5gemma.py

Lines changed: 3 additions & 1 deletion

@@ -797,7 +797,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if not self.training and use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            # We do not pass the config to the cross attn cache to avoid initializing SWA
+            # --> we use full attention between our cross attentions
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
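
Why the second DynamicCache no longer receives the config: a config-aware DynamicCache sets up sliding-window attention (SWA) handling for the layers the config marks as sliding, which is what the decoder's self-attention cache needs. The cross-attention cache instead holds the projected encoder states, which every decoder layer attends to in full, so it must stay an unbounded full-attention cache. Below is a minimal sketch of the construction, assuming a default T5GemmaConfig whose decoder sub-config carries the sliding-window settings; the names are illustrative and not taken from the diff.

from transformers import DynamicCache, EncoderDecoderCache, T5GemmaConfig

# Illustrative stand-in for `self.config` inside the decoder module.
decoder_config = T5GemmaConfig().decoder

# Self-attention cache: config-aware, so layers flagged as sliding-window get SWA handling.
self_attention_cache = DynamicCache(config=decoder_config)

# Cross-attention cache: intentionally config-free, since cross attention always attends
# to the full encoder output and its keys/values must never be windowed.
cross_attention_cache = DynamicCache()

past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)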

src/transformers/models/t5gemma/modular_t5gemma.py

Lines changed: 3 additions & 1 deletion

@@ -835,7 +835,9 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids)

         if not self.training and use_cache and past_key_values is None:
-            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
+            # We do not pass the config to the cross attn cache to avoid initializing SWA
+            # --> we use full attention between our cross attentions
+            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache())
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(

tests/models/t5gemma/test_modeling_t5gemma.py

Lines changed: 16 additions & 0 deletions

@@ -19,11 +19,14 @@

 import pytest
 from parameterized import parameterized
+from pytest import mark

 from transformers import T5GemmaConfig, T5GemmaModuleConfig, is_torch_available
 from transformers.testing_utils import (
+    require_flash_attn,
     require_torch,
     require_torch_accelerator,
+    require_torch_gpu,
     torch_device,
 )

@@ -1267,6 +1270,19 @@ def test_flex_attention_with_grads(self):
         # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
         _ = model(**dummy_inputs)

+    @require_flash_attn
+    @require_torch_gpu
+    @mark.flash_attn_test
+    def test_generate_beyond_sliding_window_with_flash_attn(self):
+        config, input_ids, _, attention_mask, _, _ = self.model_tester.prepare_config_and_inputs()
+        config.decoder.sliding_window = 2  # arbitrary but less than seq_len
+
+        model = self.model_tester.causal_lm_class(config=config).to(dtype=torch.float16, device=torch_device).eval()
+        model.set_attn_implementation("flash_attention_2")
+
+        # Only generate beyond prefill, we don't care about the output as it only checks for crashes
+        _ = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=2, use_cache=True)
+

 class T5GemmaEncoderOnlyModelTester:
     config_class = T5GemmaConfig
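
What the new test exercises: with sliding_window forced to 2, generating even two tokens beyond the prefill pushes the decoder's self-attention cache past the window, and before this fix the config-initialized cross-attention cache was also given sliding-window treatment, which made this generation path fail under flash attention. A hedged sketch for running only this test from the root of a transformers checkout (it needs a CUDA GPU and flash-attn installed, matching the require_* decorators); the -k expression simply matches the new test's name.

import pytest

# Run just the new regression test; assumes the current working directory is the
# root of a transformers checkout with flash-attn and a CUDA GPU available.
pytest.main(
    [
        "tests/models/t5gemma/test_modeling_t5gemma.py",
        "-k", "test_generate_beyond_sliding_window_with_flash_attn",
    ]
)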
