@@ -152,7 +152,7 @@ fn resolve_attention(config: &Option<Config>, lora_adapters: &Option<String>) ->
         "flashdecoding"
     };
 
-    match config.head_dim {
+    match config.get_head_dim() {
         Some(h) if h == 64 || h == 128 || h == 256 => {
             if lora_adapters.is_some() && prefix_caching.is_none() {
                 tracing::info!("Disabling prefix caching because of lora adapters");
@@ -214,6 +214,7 @@ struct RawConfig {
     num_key_value_heads: Option<usize>,
     num_hidden_layers: Option<usize>,
     head_dim: Option<usize>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: Option<bool>,
     #[serde(rename = "num_experts_per_tok")]
@@ -233,6 +234,11 @@ struct QuantizationConfig {
 #[derive(Debug, Deserialize)]
 struct VisionConfig {}
 
+#[derive(Debug, Deserialize)]
+struct TextConfig {
+    head_dim: Option<usize>,
+}
+
 #[derive(Debug, Deserialize)]
 struct Config {
     max_position_embeddings: Option<usize>,
@@ -244,6 +250,7 @@ struct Config {
     intermediate_size: Option<usize>,
     hidden_size: Option<usize>,
     model_type: Option<String>,
+    text_config: Option<TextConfig>,
     vision_config: Option<VisionConfig>,
     is_encoder_decoder: bool,
     num_experts_per_token: usize,
@@ -253,6 +260,14 @@ struct Config {
 }
 
 impl Config {
+    fn get_head_dim(&self) -> Option<usize> {
+        self.head_dim.or_else(|| {
+            self.text_config
+                .as_ref()
+                .and_then(|text_config| text_config.head_dim)
+        })
+    }
+
     fn flop(&self) -> Option<u64> {
         if self.vision_config.is_some() {
             // VLM are much harder to predict and VRAM requirements
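The added `get_head_dim` helper reads the top-level `head_dim` first and only falls back to the nested `text_config.head_dim` that multimodal model configs commonly carry. Below is a minimal, self-contained sketch of that fallback behavior; the JSON input and the `serde_json` round-trip are illustrative, not part of this commit:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct TextConfig {
    head_dim: Option<usize>,
}

#[derive(Debug, Deserialize)]
struct Config {
    head_dim: Option<usize>,
    text_config: Option<TextConfig>,
}

impl Config {
    // Same fallback chain as the commit: the top-level value wins,
    // otherwise consult the nested text_config.
    fn get_head_dim(&self) -> Option<usize> {
        self.head_dim.or_else(|| {
            self.text_config
                .as_ref()
                .and_then(|text_config| text_config.head_dim)
        })
    }
}

fn main() {
    // No top-level "head_dim"; only the nested one is set (illustrative values).
    let raw = r#"{ "text_config": { "head_dim": 256 } }"#;
    let config: Config = serde_json::from_str(raw).unwrap();
    assert_eq!(config.get_head_dim(), Some(256));
}
```

Because `Option::or_else` short-circuits, a config that does define a top-level `head_dim` never consults `text_config`.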
@@ -261,7 +276,7 @@ impl Config {
         }
         let num_heads = self.num_heads? as u64;
         let num_kv_heads = self.num_kv_heads? as u64;
-        let head_dim = self.head_dim? as u64;
+        let head_dim = self.get_head_dim()? as u64;
         let hidden_size = self.hidden_size? as u64;
         let intermediate_size = (self.intermediate_size?
             * (self.num_experts_per_token + self.num_shared_experts))
@@ -289,7 +304,7 @@ impl Config {
         }
         // 2 for key and values
         // 2 for f16 dtype?
-        Some(self.num_kv_heads? * 2 * self.head_dim? * 2 * self.num_layers?)
+        Some(self.num_kv_heads? * 2 * self.get_head_dim()? * 2 * self.num_layers?)
     }
 
     fn mlp_vram_per_tok(&self) -> Option<usize> {
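Per the comments in the hunk above, `kv_vram_per_tok` counts two tensors (key and value) of `head_dim` f16 elements (2 bytes each) per KV head per layer. A quick worked example with illustrative numbers (roughly an 8B-class dense model, not values taken from this diff):

```rust
fn main() {
    // Illustrative config: 8 KV heads, head_dim 128, 32 layers.
    let (num_kv_heads, head_dim, num_layers) = (8usize, 128, 32);
    // 2 for key and values, 2 for the f16 dtype, as in the formula above.
    let kv_vram_per_tok = num_kv_heads * 2 * head_dim * 2 * num_layers;
    assert_eq!(kv_vram_per_tok, 131_072); // 128 KiB of KV cache per token
}
```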
@@ -310,8 +325,8 @@ impl Config {
     }
 
     fn model_vram(&self) -> Option<usize> {
-        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.head_dim?;
-        let o_vram = self.num_heads? * self.head_dim? * self.hidden_size?;
+        let attn_vram = (self.num_heads? + 2 * self.num_kv_heads?) * self.get_head_dim()?;
+        let o_vram = self.num_heads? * self.get_head_dim()? * self.hidden_size?;
         // gate + up + down = 3
         let mlp_vram = 3 * self.intermediate_size? * self.num_experts * self.hidden_size?;
         let layer_vram = mlp_vram + attn_vram + o_vram;
@@ -349,6 +364,7 @@ impl From<RawConfig> for Config {
         let num_kv_heads = other.num_key_value_heads.or(other.num_attention_heads);
         let intermediate_size = other.intermediate_size;
         let model_type = other.model_type;
+        let text_config = other.text_config;
         let vision_config = other.vision_config;
         let is_encoder_decoder = other.is_encoder_decoder.unwrap_or(false);
         let num_experts_per_token = other.num_experts_per_token.unwrap_or(1);
@@ -360,6 +376,7 @@ impl From<RawConfig> for Config {
             quantize,
             head_dim,
             model_type,
+            text_config,
             vision_config,
             is_encoder_decoder,
             hidden_size,