@@ -26,7 +26,7 @@ class ModelArgs:
     dim: int = 4096
     intermediate_size: int = None
     n_local_heads: int = -1
-    head_dim: int = 64
+    head_dim: int = None
     rope_base: float = 10000
     norm_eps: float = 1e-5
 
@@ -37,7 +37,8 @@ def __post_init__(self):
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_head
 
     @classmethod
     def from_name(cls, name: str):
@@ -51,6 +52,7 @@ def from_name(cls, name: str):
 
 transformer_configs = {
     "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
+    "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256),
     "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
     "7B": dict(n_layer=32, n_head=32, dim=4096),
     "13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -95,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None:
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
-        head_dim = self.config.dim // self.config.n_head
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
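
With head_dim stored on the config, setup_caches no longer recomputes dim // n_head, so the KV cache and the RoPE table both use the true per-head width (256 for gemma-7b instead of 192). As a rough illustration, here is a generic rotary-frequency sketch, assuming the usual one-frequency-per-channel-pair layout; the repo's precompute_freqs_cis may differ in details:

import torch

def rope_freqs_sketch(seq_len: int, head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE: one frequency per pair of channels -> head_dim // 2 frequencies.
    freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    return torch.outer(t, freqs)    # (seq_len, head_dim // 2) rotation angles

# gemma-7b per the config above: head_dim=256 gives 128 rotary frequencies per head;
# the old dim // n_head = 192 would have produced 96 and disagreed with the cache width.
print(rope_freqs_sketch(seq_len=8, head_dim=256).shape)   # torch.Size([8, 128])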
@@ -145,7 +146,7 @@ def __init__(self, config: ModelArgs):
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False)
         self.kv_cache = None
 
         self.n_head = config.n_head
@@ -165,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -183,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim)
 
         y = self.wo(y)
         return y
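
The wo, split, and final view changes are all one fix: once head_dim is decoupled from dim // n_head, the concatenated attention output is n_head * head_dim wide, not dim. For gemma-7b that is 16 * 256 = 4096 while dim is 3072, so the old self.dim-based shapes would no longer line up. A quick arithmetic check of the widths implied by the diff (gemma-7b values, for illustration only):

# Projection widths implied by the Attention changes above.
dim, n_head, n_local_heads, head_dim = 3072, 16, 16, 256

q_size = n_head * head_dim                                 # 4096: query slice of the wqkv output
kv_size = n_local_heads * head_dim                         # 4096: each key/value slice
total_head_dim = (n_head + 2 * n_local_heads) * head_dim   # 12288: wqkv output features

assert total_head_dim == q_size + 2 * kv_size
assert q_size != dim      # so split([self.dim, ...]) and view(..., self.dim) had to change
# wo now maps n_head * head_dim (4096) back down to dim (3072).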
@@ -197,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.gelu(self.w1(x)) * self.w3(x))
+        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))
 
 
 class RMSNorm(nn.Module):
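
The FeedForward change swaps the gate activation to PyTorch's tanh-approximate GELU, presumably to match the reference Gemma activation. For comparison, a small sketch of the approximation behind approximate="tanh" (it mirrors the documented formula and is shown only to illustrate the gap from exact GELU):

import math
import torch
import torch.nn.functional as F

def gelu_tanh_sketch(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.randn(4)
print(torch.allclose(gelu_tanh_sketch(x), F.gelu(x, approximate="tanh"), atol=1e-5))   # True
print((F.gelu(x) - F.gelu(x, approximate="tanh")).abs().max())                         # small but nonzero gap vs. exact GELU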