
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+import math
 from dataclasses import dataclass
 from typing import Optional

@@ -29,6 +30,7 @@ class ModelArgs:
     head_dim: int = 64
     rope_base: float = 10000
     norm_eps: float = 1e-5
+    rope_scaling: Optional[dict] = None

     def __post_init__(self):
         if self.n_local_heads == -1:
@@ -68,6 +70,9 @@ def from_name(cls, name: str):

     "llama-3-8b": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000),
     "llama-3-70b": dict(block_size=8192, n_layer=80, n_head=64, n_local_heads=8, dim=8192, intermediate_size=28672, vocab_size=128256, rope_base=500000),
+    "llama-3.1-405b": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000,
+        rope_scaling=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192),
+    ),
 }

 class KVCache(nn.Module):
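
The new "llama-3.1-405b" entry carries its RoPE scaling parameters alongside the usual model sizes, so the extended 131072-token context is configured purely through data. A minimal usage sketch (not part of this commit), assuming `ModelArgs.from_name` resolves exact keys of `transformer_configs` as suggested by the hunk header above:

    args = ModelArgs.from_name("llama-3.1-405b")
    assert args.block_size == 131072
    assert args.rope_scaling["factor"] == 8.0
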
@@ -119,7 +124,7 @@ def setup_caches(self, max_batch_size, max_seq_length):
         for b in self.layers:
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype, self.config.rope_scaling)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -230,11 +235,36 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight


+def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Optional[dict] = None):
+    factor = rope_scaling["factor"]
+    low_freq_factor = rope_scaling["low_freq_factor"]
+    high_freq_factor = rope_scaling["high_freq_factor"]
+    old_context_len = rope_scaling["original_max_position_embeddings"]
+
+    low_freq_wavelen = old_context_len / low_freq_factor
+    high_freq_wavelen = old_context_len / high_freq_factor
+    new_freqs = []
+    for freq in freqs:
+        wavelen = 2 * math.pi / freq
+        if wavelen < high_freq_wavelen:
+            new_freqs.append(freq)
+        elif wavelen > low_freq_wavelen:
+            new_freqs.append(freq / factor)
+        else:
+            assert low_freq_wavelen != high_freq_wavelen
+            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
+    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
+
+
 def precompute_freqs_cis(
     seq_len: int, n_elem: int, base: int = 10000,
-    dtype: torch.dtype = torch.bfloat16
+    dtype: torch.dtype = torch.bfloat16,
+    rope_scaling: Optional[dict] = None,
 ) -> Tensor:
     freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
+    if rope_scaling is not None:
+        freqs = apply_rope_scaling(freqs, rope_scaling)
     t = torch.arange(seq_len, device=freqs.device)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
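
For intuition: `apply_rope_scaling` partitions the rotary frequencies by wavelength. Dimensions whose wavelength is shorter than `old_context_len / high_freq_factor` are kept as-is, those longer than `old_context_len / low_freq_factor` are slowed down by `factor`, and the band in between is linearly interpolated between the two. A self-contained sketch (not part of the commit) that reproduces the band boundaries for the llama-3.1-405b parameters above:

    import math
    import torch

    # Parameters from the llama-3.1-405b config added in this commit.
    factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
    n_elem, base = 128, 500000  # head_dim = dim // n_head = 16384 // 128

    # Same base frequencies as precompute_freqs_cis.
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    wavelen = 2 * math.pi / freqs

    high_freq_wavelen = old_context_len / high_freq_factor  # 2048: below this, freq is kept
    low_freq_wavelen = old_context_len / low_freq_factor    # 8192: above this, freq / 8
    kept = (wavelen < high_freq_wavelen).sum().item()
    scaled = (wavelen > low_freq_wavelen).sum().item()
    print(kept, scaled, len(freqs) - kept - scaled)  # kept / fully scaled / interpolated dims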