@@ -45,7 +45,13 @@ def from_name(cls, name: str):
4545 return cls (** transformer_configs [name ])
4646 # fuzzy search
4747 config = [config for config in transformer_configs if config in str (name ).upper () or config in str (name )]
48- assert len (config ) == 1 , name
48+
49+ # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
50+ # take longer name (as it have more symbols matched)
51+ if len (config ) > 1 :
52+ config .sort (key = len , reverse = True )
53+ assert len (config [0 ]) != len (config [1 ]), name # make sure only one 'best' match
54+
4955 return cls (** transformer_configs [config [0 ]])
5056
5157
@@ -56,6 +62,7 @@ def from_name(cls, name: str):
5662 "30B" : dict (n_layer = 60 , n_head = 52 , dim = 6656 ),
5763 "34B" : dict (n_layer = 48 , n_head = 64 , dim = 8192 , vocab_size = 32000 , n_local_heads = 8 , intermediate_size = 22016 , rope_base = 1000000 ), # CodeLlama-34B-Python-hf
5864 "70B" : dict (n_layer = 80 , n_head = 64 , dim = 8192 , n_local_heads = 8 , intermediate_size = 28672 ),
65+ "Mistral-7B" : dict (n_layer = 32 , n_head = 32 , n_local_heads = 8 , dim = 4096 , intermediate_size = 14336 , vocab_size = 32000 ),
5966}
6067
6168class KVCache (nn .Module ):
0 commit comments