
Commit fcea1e1

Fixes Flash Attention implementation for models (#42149)
* flash-att3 fix for smolvlm2
* flash-att3 fix for idefics2
* idefics2 changes
* reset idefics2
1 parent 563f2ff commit fcea1e1

File tree: 2 files changed, +14 additions, -16 deletions

src/transformers/models/idefics3/modeling_idefics3.py

Lines changed: 7 additions & 8 deletions
@@ -24,7 +24,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput
@@ -506,13 +506,12 @@ def forward(
         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
 
         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
-        # The call to `_upad_input` in `_flash_attention_forward` is expensive
-        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
-        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
-        if self.config._attn_implementation != "flash_attention_2":
-            patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
-        elif not torch.any(~patch_attention_mask):
-            patch_attention_mask = None
+        # Create the correct attention mask based on the attention implementation
+        patch_attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=patch_attention_mask,
+        )
 
         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
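The branch removed here special-cased flash attention: for every other attention implementation the 2D patch padding mask was expanded into a 4D additive mask via `_prepare_4d_attention_mask`, while for flash attention the mask was dropped entirely whenever no patch was masked, since a full mask is equivalent to no mask and skipping it avoids the expensive `_upad_input` path. Below is a minimal sketch of that old behaviour in plain PyTorch, with a hypothetical `expand_to_4d` helper standing in for `_prepare_4d_attention_mask`:

import torch

def expand_to_4d(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Hypothetical stand-in for `_prepare_4d_attention_mask`: turn a [batch, seq]
    # padding mask into a broadcastable [batch, 1, 1, seq] additive mask where
    # masked positions get a large negative value.
    inverted = 1.0 - mask_2d.to(dtype)
    return inverted[:, None, None, :] * torch.finfo(dtype).min

def old_mask_logic(attn_implementation: str, patch_attention_mask: torch.Tensor, dtype: torch.dtype):
    # Sketch of the deleted branch, mirroring the diff above.
    if attn_implementation != "flash_attention_2":
        return expand_to_4d(patch_attention_mask, dtype)
    elif not torch.any(~patch_attention_mask):
        # All patches are attended to: pass no mask at all to flash attention.
        return None
    # Flash attention with real padding kept the 2D mask as-is.
    return patch_attention_mask

# Toy check: a full mask becomes None under flash attention, 4D otherwise.
full_mask = torch.ones(2, 16, dtype=torch.bool)
print(old_mask_logic("flash_attention_2", full_mask, torch.float16))  # None
print(old_mask_logic("sdpa", full_mask, torch.float16).shape)         # torch.Size([2, 1, 1, 16])

The new code replaces this whole branch with a single `create_bidirectional_mask` call, so the per-implementation handling lives in `masking_utils` instead of being duplicated in each vision model.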

src/transformers/models/smolvlm/modeling_smolvlm.py

Lines changed: 7 additions & 8 deletions
@@ -29,7 +29,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationConfig, GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput
@@ -396,13 +396,12 @@ def forward(
         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
 
         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
-        # The call to `_upad_input` in `_flash_attention_forward` is expensive
-        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
-        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
-        if self.config._attn_implementation != "flash_attention_2":
-            patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
-        elif not torch.any(~patch_attention_mask):
-            patch_attention_mask = None
+        # Create the correct attention mask based on the attention implementation
+        patch_attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=patch_attention_mask,
+        )
 
         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
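The same replacement is made here: both vision towers now hand the config, the patch embeddings, and the 2D padding mask to `create_bidirectional_mask` and let it decide what the configured attention implementation needs. The sketch below is an illustrative approximation of what such a helper has to provide at this call site; it is inferred from the diff, not from the actual `transformers.masking_utils` implementation, whose behaviour may differ in detail:

import torch

def bidirectional_mask_sketch(attn_implementation, input_embeds, attention_mask):
    # Illustrative approximation of the contract at the call site above;
    # not the real `create_bidirectional_mask` internals.
    if attention_mask is None or not torch.any(~attention_mask.bool()):
        # Nothing is masked: every implementation can run without a mask.
        return None
    if attn_implementation == "flash_attention_2":
        # Flash-attention kernels consume the 2D padding mask directly
        # (it is used for unpadding, not as an additive bias).
        return attention_mask
    # Eager/SDPA-style kernels expect a broadcastable additive mask with a
    # large negative value at masked positions.
    dtype = input_embeds.dtype
    inverted = 1.0 - attention_mask.to(dtype)
    return inverted[:, None, None, :] * torch.finfo(dtype).min

# Toy usage mirroring the new call site: a full patch mask needs no mask at all.
hidden = torch.randn(2, 16, 8, dtype=torch.float16)
patch_mask = torch.ones(2, 16, dtype=torch.bool)
print(bidirectional_mask_sketch("sdpa", hidden, patch_mask))  # None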
