Commit 2b8068c

Aravind-11 and vasqu authored
T5 migration to new masking interface (#41804)
* Refactor: migrate T5 attention masking to masking_utils interface
* create_bidirectional_mask function with appropriate parameters
* fixup executorch + import
* revert causal masks
* rm executorch stuff
* add causal mask with non vmap
* copies
* remove unnecessary import

Co-authored-by: Vasqu <antonprogamer@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
1 parent 33c60a5 commit 2b8068c

File tree

2 files changed: +42 −325 lines


src/transformers/models/mt5/modeling_mt5.py

Lines changed: 21 additions & 162 deletions
@@ -25,7 +25,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -41,18 +41,11 @@
     DUMMY_INPUTS,
     DUMMY_MASK,
     auto_docstring,
-    is_torch_flex_attn_available,
-    is_torchdynamo_compiling,
     logging,
 )
 from .configuration_mt5 import MT5Config
 
 
-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-
-    from ...integrations.flex_attention import make_flex_block_causal_mask
-
 logger = logging.get_logger(__name__)
 
 
@@ -735,40 +728,31 @@ def forward(
                 past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
             )
 
-        if attention_mask is None and not is_torchdynamo_compiling():
-            # required mask seq length can be calculated via length of past cache
-            mask_seq_length = past_key_values_length + seq_length
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-
         if self.config.is_decoder:
-            causal_mask = self._update_causal_mask(
-                attention_mask,
-                inputs_embeds,
-                cache_position,
-                past_key_values.self_attention_cache
+            attention_mask = create_causal_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                cache_position=cache_position,
+                past_key_values=past_key_values.self_attention_cache
                 if isinstance(past_key_values, EncoderDecoderCache)
                 else past_key_values,
-                output_attentions,
             )
-        elif attention_mask is not None:
-            causal_mask = attention_mask[:, None, None, :]
-            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
-            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
         else:
-            causal_mask = None
+            attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+            )
 
-        # If a 2D or 3D attention mask is provided for the cross-attention
-        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        encoder_extended_attention_mask = None
         if self.is_decoder and encoder_hidden_states is not None:
-            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
-            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
-            if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(
-                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
-                )
-            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-        else:
-            encoder_extended_attention_mask = None
+            encoder_extended_attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=encoder_attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+            )
 
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -778,13 +762,13 @@ def forward(
 
         hidden_states = self.dropout(inputs_embeds)
 
-        for i, layer_module in enumerate(self.block):
+        for layer_module in self.block:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer_outputs = layer_module(
                 hidden_states,
-                causal_mask,
+                attention_mask,
                 position_bias,
                 encoder_hidden_states,
                 encoder_extended_attention_mask,
@@ -837,131 +821,6 @@ def forward(
             cross_attentions=all_cross_attentions,
         )
 
-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
-    def _update_causal_mask(
-        self,
-        attention_mask: Union[torch.Tensor, "BlockMask"],
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool = False,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-        if self.config._attn_implementation == "flex_attention":
-            if isinstance(attention_mask, torch.Tensor):
-                attention_mask = make_flex_block_causal_mask(attention_mask)
-            return attention_mask
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype = input_tensor.dtype
-        sequence_length = input_tensor.shape[1]
-        if using_compilable_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type in ["cuda", "xpu", "npu"]
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
-                    causal_mask.device
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-
 
 @auto_docstring
 class MT5Model(MT5PreTrainedModel):
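
The diff above replaces MT5's hand-rolled `_update_causal_mask` / `invert_attention_mask` logic with the shared `create_causal_mask` and `create_bidirectional_mask` helpers from `masking_utils`. Below is a minimal sketch of that call pattern outside the model, mirroring only the keyword arguments visible in the diff; the config, tensor shapes, and `past_key_values=None` setup are hypothetical stand-ins for illustration, and `create_bidirectional_mask` is only available in transformers versions that include this commit.

    import torch
    from transformers import MT5Config
    from transformers.masking_utils import create_bidirectional_mask, create_causal_mask

    # Hypothetical setup purely for illustration; not taken from the modeling code.
    config = MT5Config()
    batch_size, seq_len = 2, 5
    inputs_embeds = torch.randn(batch_size, seq_len, config.d_model)
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    cache_position = torch.arange(seq_len)

    # Decoder self-attention: a cache-aware causal mask (may be None when the
    # attention backend can enforce causality itself, e.g. SDPA's is_causal path).
    decoder_mask = create_causal_mask(
        config=config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=None,  # in the model: past_key_values.self_attention_cache when caching
    )

    # Encoder self-attention (and cross-attention in the decoder): a padding-only
    # bidirectional mask built from the same 2D attention_mask.
    encoder_mask = create_bidirectional_mask(
        config=config,
        input_embeds=inputs_embeds,
        attention_mask=attention_mask,
    )

With this interface, the mask that each `T5Block` receives is already in the layout the configured attention implementation expects, which is why the forward pass can pass `attention_mask` straight through instead of maintaining a separate `causal_mask` variable.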

0 commit comments
