@@ -601,34 +601,35 @@ def __init__(
601601 super ().__init__ (dim , max_position_embeddings , base )
602602
603603 def _set_cos_sin_cache (self , seq_len ):
604- self .max_seq_len_cached = seq_len
605- dim = self .dim
606-
607- freq_extra = 1.0 / (self .base ** (paddle .arange (0 , dim , 2 , dtype = paddle .float32 ) / dim ))
608- freq_inter = 1.0 / (self .scaling_factor * self .base ** (paddle .arange (0 , dim , 2 , dtype = paddle .float32 ) / dim ))
609-
610- low , high = yarn_find_correction_range (
611- self .beta_fast ,
612- self .beta_slow ,
613- dim ,
614- self .base ,
615- self .original_max_position_embeddings ,
616- )
617- inv_freq_mask = 1.0 - yarn_linear_ramp_mask (low , high , dim // 2 )
618- self .inv_freq = freq_inter * (1 - inv_freq_mask ) + freq_extra * inv_freq_mask
604+ with paddle .amp .auto_cast (False ):
605+ self .max_seq_len_cached = seq_len
606+ dim = self .dim
607+
608+ freq_extra = 1.0 / (self .base ** (paddle .arange (0 , dim , 2 , dtype = paddle .float32 ) / dim ))
609+ freq_inter = 1.0 / (self .scaling_factor * self .base ** (paddle .arange (0 , dim , 2 , dtype = paddle .float32 ) / dim ))
610+
611+ low , high = yarn_find_correction_range (
612+ self .beta_fast ,
613+ self .beta_slow ,
614+ dim ,
615+ self .base ,
616+ self .original_max_position_embeddings ,
617+ )
618+ inv_freq_mask = 1.0 - yarn_linear_ramp_mask (low , high , dim // 2 )
619+ self .inv_freq = freq_inter * (1 - inv_freq_mask ) + freq_extra * inv_freq_mask
619620
620- t = paddle .arange (seq_len , dtype = paddle .float32 )
621+ t = paddle .arange (seq_len , dtype = paddle .float32 )
621622
622- freqs = paddle .outer (t , paddle .cast (self .inv_freq , dtype = "float32" ))
623+ freqs = paddle .outer (t , paddle .cast (self .inv_freq , dtype = "float32" ))
623624
624- _mscale = float (
625- yarn_get_mscale (self .scaling_factor , self .mscale )
626- / yarn_get_mscale (self .scaling_factor , self .mscale_all_dim )
627- )
625+ _mscale = float (
626+ yarn_get_mscale (self .scaling_factor , self .mscale )
627+ / yarn_get_mscale (self .scaling_factor , self .mscale_all_dim )
628+ )
628629
629- emb = paddle .concat ((freqs , freqs ), axis = - 1 )
630- self .cos_cached = emb .cos () * _mscale
631- self .sin_cached = emb .sin () * _mscale
630+ emb = paddle .concat ((freqs , freqs ), axis = - 1 )
631+ self .cos_cached = emb .cos () * _mscale
632+ self .sin_cached = emb .sin () * _mscale
632633
633634
634635def rotate_half (x ):
0 commit comments