@@ -50,6 +50,7 @@ def from_name(cls, name: str):
 
 
 transformer_configs = {
+    "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
     "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
     "7B": dict(n_layer=32, n_head=32, dim=4096),
     "13B": dict(n_layer=40, n_head=40, dim=5120),
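The new config entry gives Gemma 2B's shape: 18 layers, hidden size 2048, 8 query heads sharing a single key/value head (multi-query attention), a 256,000-token vocabulary, and a 16384-wide MLP. A minimal sketch of what the entry implies, using only the keys defined in the dict above (illustrative, not part of the patch):

# Illustrative use of the new entry; values come straight from the dict above.
cfg = transformer_configs["gemma-2b"]
head_dim = cfg["dim"] // cfg["n_head"]   # 2048 // 8 = 256
kv_heads = cfg["n_local_heads"]          # 1: all 8 query heads share one KV head (multi-query attention)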
@@ -109,6 +110,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)
+        x = (self.config.dim ** 0.5) * x
 
         for i, layer in enumerate(self.layers):
             x = layer(x, input_pos, freqs_cis, mask)
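The added line applies Gemma's embedding normalization: token embeddings are multiplied by sqrt(dim) before entering the first transformer block. A self-contained sketch of the same scaling (token ids and shapes are illustrative):

import torch

# Sketch of the sqrt(dim) embedding scaling added above (dim = 2048 for gemma-2b).
dim = 2048
tok_embeddings = torch.nn.Embedding(256000, dim)
idx = torch.tensor([[2, 651, 2121]])        # hypothetical token ids
x = (dim ** 0.5) * tok_embeddings(idx)      # scale factor sqrt(2048) ≈ 45.25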
@@ -195,7 +197,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+        return self.w2(F.gelu(self.w1(x)) * self.w3(x))
 
 
 class RMSNorm(nn.Module):
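In the FeedForward block, the SiLU gate used by Llama (SwiGLU) is replaced with a GELU gate (GeGLU), which is what Gemma uses. One caveat worth flagging: Gemma's reference implementation uses the tanh-approximate GELU, so F.gelu(..., approximate="tanh") may be needed for an exact numerical match. A standalone sketch of the gated MLP, with dimensions taken from the gemma-2b entry:

import torch
import torch.nn.functional as F

# GELU-gated MLP (GeGLU) as in the change above; gemma-2b sizes.
dim, hidden = 2048, 16384
w1 = torch.nn.Linear(dim, hidden, bias=False)   # gate projection
w3 = torch.nn.Linear(dim, hidden, bias=False)   # up projection
w2 = torch.nn.Linear(hidden, dim, bias=False)   # down projection

x = torch.randn(1, 4, dim)
y = w2(F.gelu(w1(x)) * w3(x))                   # GELU gate instead of SiLU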
@@ -209,7 +211,7 @@ def _norm(self, x):
 
     def forward(self, x: Tensor) -> Tensor:
         output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return output * (1 + self.weight)
 
 
 def precompute_freqs_cis(
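The RMSNorm change matches Gemma's parameterization: the learned scale is stored as an offset around 1, so the normalized output is multiplied by (1 + weight) and a zero-initialized weight acts as the identity. Gemma checkpoints store the scale in this zero-centered form. A self-contained sketch of the idea (class name is hypothetical):

import torch

# Gemma-style RMSNorm sketch: scale applied as (1 + weight), weight initialized to zeros.
class GemmaStyleRMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.zeros(dim))  # zeros: identity scale at init

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same normalization as _norm in this file: x * rsqrt(mean(x^2) + eps), in float32.
        norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return norm.type_as(x) * (1 + self.weight)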