@@ -508,17 +508,15 @@ def modify(
         # https://github.com/huggingface/transformers/blob/v4.56-release/src/transformers/trainer.py#L566
         self.is_quantized = False

-        self.num_ttt_steps = 3  # NOTE: (hg) hardcoded for now. Might add to config later.
-        self._cached_attn_blk_masks = []
+        self.num_ttt_steps = 4  # NOTE: (hg) hardcoded for now. Might add to config later.
+        self._cached_attn_blk_masks = {}

     def _get_ttt_attention_mask(self, seq_length, ttt_step):
         # Compile and cache flex attention masks on the first call.
-        if ttt_step >= len(self._cached_attn_blk_masks):
-            self._cached_attn_blk_masks.append(
-                self._compute_ttt_attention_mask(seq_length, ttt_step)
+        if ttt_step not in self._cached_attn_blk_masks:
+            self._cached_attn_blk_masks.update(
+                {ttt_step: self._compute_ttt_attention_mask(seq_length, ttt_step)}
             )
-
-        # return cached flex attention mask
         return self._cached_attn_blk_masks[ttt_step]

     def _prepare_decoder_attention_mask(
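
This hunk switches the mask cache from an ordered list to a dict keyed by `ttt_step`, so masks can be built lazily in any step order (a list-based cache mis-indexes if a later step is requested first). A minimal standalone sketch of that memoization pattern, with a hypothetical `MaskCache` name standing in for the module:

```python
class MaskCache:
    """Lazily compute and memoize one attention mask per TTT step."""

    def __init__(self, compute_fn):
        self._compute_fn = compute_fn  # e.g. _compute_ttt_attention_mask
        self._cache = {}

    def get(self, seq_length, ttt_step):
        # Build the mask on the first request for this step; reuse it afterwards.
        # Keyed only by ttt_step (mirroring the diff), so seq_length is assumed fixed.
        if ttt_step not in self._cache:
            self._cache[ttt_step] = self._compute_fn(seq_length, ttt_step)
        return self._cache[ttt_step]
```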
@@ -600,44 +598,26 @@ def _get_eagle_module_inputs(

     def _compute_ttt_attention_mask(self, seq_length, ttt_step) -> BlockMask | torch.Tensor:
         """Return a TTT attention mask of type BlockMask or Tensor, depending on the eagle attn impl."""
-        if ttt_step == 0:
-
-            def msk(b, h, q_idx, kv_idx):
-                # symbolic attention mask of shape [seq_len, 2 * seq_len] for TTT step 0
-                return (kv_idx <= (q_idx - 1)) | (kv_idx == q_idx + seq_length)

-        elif ttt_step == 1:
-
-            def msk(b, h, q_idx, kv_idx):
-                # attention mask of shape [seq_len, 3 * seq_len] for TTT step 1
-                return (
-                    (kv_idx <= (q_idx - 2))
-                    | ((kv_idx == q_idx + seq_length - 1) & (kv_idx >= seq_length))
-                    | ((kv_idx == q_idx + 2 * seq_length) & (kv_idx >= seq_length * 2))
-                )
-        elif ttt_step == 2:
-
-            def msk(b, h, q_idx, kv_idx):
-                # attention mask of shape [seq_len, 4 * seq_len] for TTT step 2
-                return (
-                    (kv_idx <= (q_idx - 3))
-                    | ((kv_idx == q_idx + seq_length - 2) & (kv_idx >= seq_length))
-                    | ((kv_idx == q_idx + 2 * seq_length - 1) & (kv_idx >= seq_length * 2))
-                    | ((kv_idx == q_idx + 3 * seq_length) & (kv_idx >= seq_length * 3))
+        def msk_func(b, h, q_idx, kv_idx):
+            mask = kv_idx <= (q_idx - ttt_step)
+            for i in range(1, ttt_step + 1):
+                mask_block_i = (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & (
+                    kv_idx >= seq_length * i
                 )
-        else:
-            raise ValueError(f"EAGLE TTT step {ttt_step} is not supported")
+                mask = mask | mask_block_i
+            return mask

         dtypemin = torch.finfo(self._base_llm_config.dtype).min
         q_len = seq_length
-        kv_len = seq_length * (2 + ttt_step)
+        kv_len = seq_length * (1 + ttt_step)
         if self.eagle_module.config._attn_implementation == "flex_attention":
             # Return block mask for flex attention
-            block_mask = create_block_mask(msk, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
+            block_mask = create_block_mask(msk_func, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len)
             return block_mask
         else:
             # Return tensor mask for non-flex attention
-            tensor_mask = msk(
+            tensor_mask = msk_func(
                 None,
                 None,
                 torch.arange(q_len).view(1, 1, q_len, 1),
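
For reference, a small self-contained check of the generalized mask (illustrative only, not part of the commit): `msk_func` builds a causal block shifted by `ttt_step` plus one diagonal block per earlier step, over a KV length of `(1 + ttt_step) * seq_length`, so step `t` reproduces the mask that was previously hardcoded for step `t - 1`. The same callable is what gets handed to `create_block_mask` when flex attention is enabled.

```python
import torch

def msk_func(seq_length, ttt_step, q_idx, kv_idx):
    # Causal block shifted back by ttt_step over the original tokens...
    mask = kv_idx <= (q_idx - ttt_step)
    # ...plus one diagonal block per earlier TTT step's drafted tokens.
    for i in range(1, ttt_step + 1):
        mask = mask | (
            (kv_idx == q_idx + i * seq_length - (ttt_step - i)) & (kv_idx >= seq_length * i)
        )
    return mask

seq_length, ttt_step = 4, 1
q = torch.arange(seq_length).view(-1, 1)
kv = torch.arange(seq_length * (1 + ttt_step)).view(1, -1)
print(msk_func(seq_length, ttt_step, q, kv).int())
# tensor([[0, 0, 0, 0, 1, 0, 0, 0],
#         [1, 0, 0, 0, 0, 1, 0, 0],
#         [1, 1, 0, 0, 0, 0, 1, 0],
#         [1, 1, 1, 0, 0, 0, 0, 1]], dtype=torch.int32)
```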
@@ -847,69 +827,54 @@ def forward(
         inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs)

         position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
-
-        # Then, we run eagle forward
-        _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
-            eagle_input_hidden_states,
-            inputs_embeds,
-            attention_mask_0,
-            position_ids,
-            position_embeddings,
-            eagle_cache,
-        )
-
         past_key_values.eagle_cache = eagle_cache

-        # Compute loss on the eagle modules
-        classification_loss, acc = self._eagle_loss(
-            base_model_logits[:, 1:],
-            eagle_logits[:, :-1],
-            loss_mask[:, 1:],
-        )
-        eagle_loss = classification_loss
-        train_accs.append(acc)
-
         # ====Perform training-time testing (TTT) with num_ttt_steps eagle forward passes====
-        if self.training:
-            for ttt_step in range(self.num_ttt_steps):
-                eagle_input_hidden_states = torch.cat(
+        for ttt_step in range(self.num_ttt_steps):
+            attention_mask = (
+                attention_mask_0
+                if ttt_step == 0
+                else self._get_ttt_attention_mask(seq_length, ttt_step)
+            )
+            _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward(
+                eagle_input_hidden_states,
+                inputs_embeds,
+                attention_mask,
+                position_ids,
+                position_embeddings,
+                eagle_cache,
+            )
+            eagle_input_hidden_states = torch.cat(
+                (
+                    torch.zeros(
+                        (b, 1, h),
+                        dtype=eagle_input_hidden_states.dtype,
+                        device=eagle_input_hidden_states.device,
+                    ),
+                    eagle_input_hidden_states[:, :-1, :],
+                ),
+                dim=1,
+            )
+            classification_loss, acc = self._eagle_loss(
+                # the base model predicts +1 tok, while eagle predicts +2,
+                # so we shift the base model outputs relative to the eagle outputs
+                base_model_logits[:, 1:],
+                eagle_logits[:, :-1],
+                # additionally, we mask the first n tok of eagle outputs at the nth TTT step
+                torch.cat(
                     (
-                        torch.zeros(
-                            (b, 1, h),
-                            dtype=eagle_input_hidden_states.dtype,
-                            device=eagle_input_hidden_states.device,
-                        ),
-                        eagle_prenorm_h[:, :-1, :],
+                        torch.zeros(b, ttt_step, dtype=loss_mask.dtype, device=loss_mask.device),
+                        loss_mask[:, 1 + ttt_step :],
                     ),
                     dim=1,
-                )
-                attention_mask = self._get_ttt_attention_mask(seq_length, ttt_step)
-                _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
-                    eagle_input_hidden_states,
-                    inputs_embeds,
-                    attention_mask,
-                    position_ids,
-                    position_embeddings,
-                    eagle_cache,
-                )
-                classification_loss, acc = self._eagle_loss(
-                    # base model predict +1 tok, while eagle predict +2
-                    # so we shift base model outputs compared to eagle outputs
-                    base_model_logits[:, 1:],
-                    eagle_logits[:, :-1],
-                    # additionally, we mask the first n tok of eagle outputs at nth TTT step
-                    torch.cat(
-                        (
-                            torch.zeros(
-                                b, 1 + ttt_step, dtype=loss_mask.dtype, device=loss_mask.device
-                            ),
-                            loss_mask[:, 2 + ttt_step :],
-                        ),
-                        dim=1,
-                    ),
-                )
-                eagle_loss += classification_loss
-                train_accs.append(acc)
+                ),
+            )
+            eagle_loss = (
+                classification_loss if eagle_loss is None else eagle_loss + classification_loss
+            )
+            train_accs.append(acc)
+            if not self.training:
+                break
         # Finally, we merge base model loss and eagle loss; raise an error if both are None
         if base_model_loss is not None and eagle_loss is not None:
             loss = base_model_loss + eagle_loss
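
To make the per-step loss masking concrete, here is a small illustrative snippet (not part of the commit) showing the mask produced by the inner `torch.cat` for each TTT step: every mask has length `seq_length - 1`, matching the shifted `base_model_logits[:, 1:]` and `eagle_logits[:, :-1]`, and step n additionally zeroes out its first n positions, per the comment about masking the first n tokens at the nth step.

```python
import torch

b, seq_length = 1, 6
loss_mask = torch.ones(b, seq_length, dtype=torch.long)

for ttt_step in range(4):  # num_ttt_steps is hardcoded to 4 in this commit
    step_mask = torch.cat(
        (
            torch.zeros(b, ttt_step, dtype=loss_mask.dtype),
            loss_mask[:, 1 + ttt_step :],
        ),
        dim=1,
    )
    print(ttt_step, step_mask.tolist())
# 0 [[1, 1, 1, 1, 1]]
# 1 [[0, 1, 1, 1, 1]]
# 2 [[0, 0, 1, 1, 1]]
# 3 [[0, 0, 0, 1, 1]]
```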