File tree Expand file tree Collapse file tree 1 file changed +16
-5
lines changed Expand file tree Collapse file tree 1 file changed +16
-5
lines changed Original file line number Diff line number Diff line change @@ -260,11 +260,22 @@ struct Config {
260260
261261impl Config {
262262 fn get_head_dim ( & self ) -> Option < usize > {
263- self . head_dim . or_else ( || {
264- self . text_config
265- . as_ref ( )
266- . and_then ( |text_config| text_config. head_dim )
267- } )
263+ if let Some ( head_dim) = self . head_dim {
264+ return Some ( head_dim) ;
265+ }
266+
267+ let text_config = self . text_config . as_ref ( ) ?;
268+ if let Some ( head_size) = text_config. head_dim {
269+ return Some ( head_size) ;
270+ }
271+
272+ match self . model_type . as_deref ( ) {
273+ // We special-case gemma3 here, since we need flashinfer for
274+ // handling bidirectional masks. And flashinfer can only be
275+ // used when the head size is known.
276+ Some ( "gemma3" ) => Some ( 256 ) ,
277+ _ => None ,
278+ }
268279 }
269280
270281 fn flop ( & self ) -> Option < u64 > {
You can’t perform that action at this time.
0 commit comments