 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...masking_utils import create_bidirectional_mask, create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -41,18 +41,11 @@
     DUMMY_INPUTS,
     DUMMY_MASK,
     auto_docstring,
-    is_torch_flex_attn_available,
-    is_torchdynamo_compiling,
     logging,
 )
 from .configuration_mt5 import MT5Config


-if is_torch_flex_attn_available():
-    from torch.nn.attention.flex_attention import BlockMask
-
-    from ...integrations.flex_attention import make_flex_block_causal_mask
-
 logger = logging.get_logger(__name__)


@@ -735,40 +728,31 @@ def forward(
             past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
         )

-        if attention_mask is None and not is_torchdynamo_compiling():
-            # required mask seq length can be calculated via length of past cache
-            mask_seq_length = past_key_values_length + seq_length
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-
         if self.config.is_decoder:
-            causal_mask = self._update_causal_mask(
-                attention_mask,
-                inputs_embeds,
-                cache_position,
-                past_key_values.self_attention_cache
+            attention_mask = create_causal_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                cache_position=cache_position,
+                past_key_values=past_key_values.self_attention_cache
                 if isinstance(past_key_values, EncoderDecoderCache)
                 else past_key_values,
-                output_attentions,
             )
-        elif attention_mask is not None:
-            causal_mask = attention_mask[:, None, None, :]
-            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
-            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
         else:
-            causal_mask = None
+            attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+            )

-        # If a 2D or 3D attention mask is provided for the cross-attention
-        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        encoder_extended_attention_mask = None
         if self.is_decoder and encoder_hidden_states is not None:
-            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
-            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
-            if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(
-                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
-                )
-            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-        else:
-            encoder_extended_attention_mask = None
+            encoder_extended_attention_mask = create_bidirectional_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=encoder_attention_mask,
+                encoder_hidden_states=encoder_hidden_states,
+            )

         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
@@ -778,13 +762,13 @@ def forward(

         hidden_states = self.dropout(inputs_embeds)

-        for i, layer_module in enumerate(self.block):
+        for layer_module in self.block:
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

             layer_outputs = layer_module(
                 hidden_states,
-                causal_mask,
+                attention_mask,
                 position_bias,
                 encoder_hidden_states,
                 encoder_extended_attention_mask,
@@ -837,131 +821,6 @@ def forward(
             cross_attentions=all_cross_attentions,
         )

-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
-    def _update_causal_mask(
-        self,
-        attention_mask: Union[torch.Tensor, "BlockMask"],
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool = False,
-    ):
-        if self.config._attn_implementation == "flash_attention_2":
-            if attention_mask is not None and (attention_mask == 0.0).any():
-                return attention_mask
-            return None
-        if self.config._attn_implementation == "flex_attention":
-            if isinstance(attention_mask, torch.Tensor):
-                attention_mask = make_flex_block_causal_mask(attention_mask)
-            return attention_mask
-
-        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
-        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
-        # to infer the attention mask.
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
-
-        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
-        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask,
-                inputs_embeds=input_tensor,
-                past_key_values_length=past_seen_tokens,
-                is_training=self.training,
-            ):
-                return None
-
-        dtype = input_tensor.dtype
-        sequence_length = input_tensor.shape[1]
-        if using_compilable_cache:
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = (
-                attention_mask.shape[-1]
-                if isinstance(attention_mask, torch.Tensor)
-                else past_seen_tokens + sequence_length + 1
-            )
-
-        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
-        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
-            attention_mask,
-            sequence_length=sequence_length,
-            target_length=target_length,
-            dtype=dtype,
-            cache_position=cache_position,
-            batch_size=input_tensor.shape[0],
-        )
-
-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type in ["cuda", "xpu", "npu"]
-            and not output_attentions
-        ):
-            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-            # Details: https://github.com/pytorch/pytorch/issues/110213
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-        return causal_mask
-
-    @staticmethod
-    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
-    def _prepare_4d_causal_attention_mask_with_cache_position(
-        attention_mask: torch.Tensor,
-        sequence_length: int,
-        target_length: int,
-        dtype: torch.dtype,
-        cache_position: torch.Tensor,
-        batch_size: int,
-        **kwargs,
-    ):
-        """
-        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
-        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
-
-        Args:
-            attention_mask (`torch.Tensor`):
-                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
-                `(batch_size, 1, query_length, key_value_length)`.
-            sequence_length (`int`):
-                The sequence length being processed.
-            target_length (`int`):
-                The target length: when generating with static cache, the mask should be as long as the static cache,
-                to account for the 0 padding, the part of the cache that is not filled yet.
-            dtype (`torch.dtype`):
-                The dtype to use for the 4D attention mask.
-            cache_position (`torch.Tensor`):
-                Indices depicting the position of the input sequence tokens in the sequence.
-            batch_size (`torch.Tensor`):
-                Batch size.
-        """
-        if attention_mask is not None and attention_mask.dim() == 4:
-            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
-            causal_mask = attention_mask
-        else:
-            min_dtype = torch.finfo(dtype).min
-            causal_mask = torch.full(
-                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
-            )
-            if sequence_length != 1:
-                causal_mask = torch.triu(causal_mask, diagonal=1)
-            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
-            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
-            if attention_mask is not None:
-                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-                mask_length = attention_mask.shape[-1]
-                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
-                    causal_mask.device
-                )
-                padding_mask = padding_mask == 0
-                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
-                    padding_mask, min_dtype
-                )
-
-        return causal_mask
-

 @auto_docstring
 class MT5Model(MT5PreTrainedModel):
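
Note on the change: the deleted `_update_causal_mask` / `_prepare_4d_causal_attention_mask_with_cache_position` pair built the additive 4D attention mask by hand inside the model, while the new `create_causal_mask` / `create_bidirectional_mask` helpers from `masking_utils` centralize that logic. As a rough standalone illustration of what such a mask looks like (plain PyTorch, not part of the diff and not the transformers API), the sketch below mirrors the deleted static method: it returns a `(batch_size, 1, query_length, key_value_length)` float mask with `0` where attention is allowed and the dtype minimum where it is blocked, folding a 2D padding mask into the causal pattern.

import torch


def sketch_causal_float_mask(attention_mask, sequence_length, target_length, dtype, cache_position, batch_size):
    # Additive causal mask: 0.0 where a query may attend, dtype-min where it may not,
    # mirroring the deleted `_prepare_4d_causal_attention_mask_with_cache_position`.
    min_dtype = torch.finfo(dtype).min
    mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal=1)  # block strictly-future positions
    # block cache slots beyond each query's absolute position (not yet filled during generation)
    mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:
        # fold the 2D padding mask (1 = real token, 0 = padding) into the causal pattern
        mask_length = attention_mask.shape[-1]
        padding = mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(dtype)
        mask[:, :, :, :mask_length] = mask[:, :, :, :mask_length].masked_fill(padding == 0, min_dtype)
    return mask


# Quick check: one sequence with two left-padding tokens followed by three real tokens, no cache.
pad = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.float32)
mask = sketch_causal_float_mask(
    pad, sequence_length=5, target_length=5, dtype=torch.float32, cache_position=torch.arange(5), batch_size=1
)
print(mask.shape)  # torch.Size([1, 1, 5, 5]); each row attends only to earlier, non-padded positions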