@@ -29,7 +29,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationConfig, GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...masking_utils import create_bidirectional_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, ModelOutput
@@ -396,13 +396,12 @@ def forward(
         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
 
         patch_attention_mask = patch_attention_mask.view(batch_size, -1)
-        # The call to `_upad_input` in `_flash_attention_forward` is expensive
-        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
-        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
-        if self.config._attn_implementation != "flash_attention_2":
-            patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
-        elif not torch.any(~patch_attention_mask):
-            patch_attention_mask = None
+        # Create the correct attention mask based on the attention implementation
+        patch_attention_mask = create_bidirectional_mask(
+            config=self.config,
+            input_embeds=hidden_states,
+            attention_mask=patch_attention_mask,
+        )
 
         encoder_outputs: BaseModelOutput = self.encoder(
             inputs_embeds=hidden_states,
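For context, the new helper centralizes the mask handling that the removed branch did by hand (building a 4D mask for eager/SDPA, skipping the mask for flash attention when every patch is real). A minimal sketch of the call pattern outside the model, assuming a transformers version that ships `transformers.masking_utils.create_bidirectional_mask`; `Idefics2VisionConfig` is used here only as a hypothetical stand-in for the vision tower's `self.config`, and the described return behavior is inferred from the branch this PR replaces:

import torch

from transformers.masking_utils import create_bidirectional_mask
from transformers.models.idefics2.configuration_idefics2 import Idefics2VisionConfig

config = Idefics2VisionConfig()  # hypothetical stand-in for the vision tower's config
hidden_states = torch.zeros(1, 64, config.hidden_size)      # embedded patches: (batch, seq, hidden)
patch_attention_mask = torch.ones(1, 64, dtype=torch.bool)  # True = real patch, False = padding

# The helper picks the mask format for the configured attention implementation,
# replacing the manual `_prepare_4d_attention_mask` / flash-attention branch above.
attention_mask = create_bidirectional_mask(
    config=config,
    input_embeds=hidden_states,
    attention_mask=patch_attention_mask,
)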
|