@@ -4971,112 +4971,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data_torch)


-@ModelBase.register("BambaForCausalLM")
-class BambaModel(Mamba2Model):
-    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
-    model_arch = gguf.MODEL_ARCH.BAMBA
-    undo_permute = True
-
-    def __init__(self, *args, **kwargs):
-
-        # Hybrid mamba models use a prefix for the mamba-specific params.
-        # TODO: Extend this if the prefix(es) need to be configurable
-        self.hparam_prefixes = ["mamba"]
-
-        super().__init__(*args, **kwargs)
-
-        # Use Llama conversion for attention
-        self._transformer_model_class: type[TextModel] = LlamaModel
-
-        # Lists of which layers use ssm vs attention
-        self._attn_layers = self.get_attn_layres()
-        self._ssm_layers = [
-            i for i in range(self.block_count)
-            if i not in self._attn_layers
-        ]
-
-        # n_group and d_inner are used during reshape_tensors for mamaba2
-        self.d_model = self.find_hparam(["hidden_size", "d_model"])
-        self.n_group = self.find_hparam(["n_groups"])
-        self.d_inner = self.find_hparam(["expand"]) * self.d_model
-
-    def get_attn_layres(self) -> list[int]:
-        attn_layers = self.hparams.get("attn_layer_indices", [])
-        if not attn_layers:
-            attn_period = self.hparams.get("attn_layer_period")
-            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
-            attn_offset = self.hparams.get("attn_layer_offset")
-            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
-            attn_layers = [
-                i for i in range(self.block_count)
-                if i % attn_period == attn_offset
-            ]
-        return attn_layers
-
-    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
-        prefixed = []
-        for pfx in self.hparam_prefixes:
-            prefixed.extend(
-                "_".join([pfx, k])
-                for k in keys
-            )
-        keys = list(keys) + prefixed
-        return super().find_hparam(keys, *args, **kwargs)
-
-    def set_gguf_parameters(self):
-
-        ## General Params ##
-        self.gguf_writer.add_embedding_length(self.d_model)
-        self.gguf_writer.add_block_count(self.block_count)
-        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
-        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
-        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
-
-        ## Mamba mixer params ##
-        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
-        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
-        self.gguf_writer.add_ssm_group_count(self.n_group)
-        self.gguf_writer.add_ssm_inner_size(self.d_inner)
-        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
-        #   in llama.cpp
-        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
-
-        ## Attention params ##
-        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
-        if rope_dim := self.hparams.get("attn_rotary_emb"):
-            self.gguf_writer.add_rope_dimension_count(rope_dim)
-        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
-        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
-
-        ## Feed Forward Params ##
-        self.gguf_writer.add_layer_norm_rms_eps(
-            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
-        )
-
-        ## Validation ##
-        d_head = self.find_hparam(["d_head"], optional=True) or 64
-        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
-        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
-
-    def modify_tensors(
-        self, data_torch: Tensor, name: str, bid: int | None
-    ) -> Iterable[tuple[str, Tensor]]:
-
-        # Determine whether this is a mamaba layer or an attention layer
-        if bid in self._ssm_layers:
-            for mamba_new_name, data_torch in super().modify_tensors(
-                data_torch, name, bid
-            ):
-                yield mamba_new_name, data_torch
-        elif bid in self._attn_layers:
-            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
-                self, data_torch, name, bid
-            ):
-                yield llama_new_name, data_torch
-        else:
-            yield self.map_tensor_name(name), data_torch
-
-
 @ModelBase.register("JambaForCausalLM")
 class JambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.JAMBA
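The removed `BambaModel.get_attn_layres` above derives the attention-layer indices either from an explicit `attn_layer_indices` list or from an `attn_layer_period` / `attn_layer_offset` pair; every other block is treated as a Mamba2 SSM layer. A minimal standalone sketch of that schedule follows. The hyperparameter names come from the diff; the 32-block config and the period/offset values are made-up illustrations.

```python
# Standalone sketch of the attn_layer_period / attn_layer_offset schedule
# used by the removed BambaModel (and by GraniteHybridModel below).
# The example config values are illustrative assumptions.

def resolve_layer_schedule(block_count: int, hparams: dict) -> tuple[list[int], list[int]]:
    attn_layers = hparams.get("attn_layer_indices", [])
    if not attn_layers:
        period = hparams["attn_layer_period"]
        offset = hparams["attn_layer_offset"]
        attn_layers = [i for i in range(block_count) if i % period == offset]
    ssm_layers = [i for i in range(block_count) if i not in attn_layers]
    return attn_layers, ssm_layers


# e.g. 32 blocks with one attention block every 6 layers, at offset 5:
attn, ssm = resolve_layer_schedule(32, {"attn_layer_period": 6, "attn_layer_offset": 5})
assert attn == [5, 11, 17, 23, 29]  # the remaining 27 blocks are Mamba2 SSM layers
```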
@@ -6579,19 +6473,66 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return super().modify_tensors(data_torch, name, bid)


-@ModelBase.register("GraniteMoeHybridForCausalLM")
-class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
-    """GraniteMoeHybrid is a hybrid SSM + MoE Attention model that uses Mamba2
-    SSM layers"""
-    model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
+@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM")
+class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
+    """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM
+    layers and optionally uses MoE w/ a shared expert"""
+    model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Use Granite conversion for attention
+        self._transformer_model_class: type[TextModel] = GraniteModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.get_attn_layres()
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamaba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
 
     def get_attn_layres(self):
+        # Explicit list of layer type names
         if layer_types := self.hparams.get("layer_types"):
             return [
                 i for i, typ in enumerate(layer_types)
                 if typ == "attention"
             ]
-        return super().get_attn_layres()
+
+        # Layer types indicated by index or period
+        attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        return attn_layers
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
 
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
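The `find_hparam` override added above widens each lookup so that Mamba-specific hyperparameters namespaced with a `mamba_` prefix still resolve. A rough self-contained sketch of that behavior follows; the config dict is invented for illustration, and only the key names mirror the diff.

```python
# Rough sketch of the prefixed hparam lookup in GraniteHybridModel.find_hparam.
# The config below is a made-up example, not a real model config.
from typing import Any, Iterable


def find_hparam(hparams: dict, keys: Iterable[str], prefixes: tuple[str, ...] = ("mamba",)) -> Any:
    keys = list(keys)
    # Try the plain keys first, then the same keys with each prefix attached.
    keys += ["_".join([pfx, k]) for pfx in prefixes for k in keys]
    for k in keys:
        if k in hparams:
            return hparams[k]
    raise KeyError(f"none of {keys} found")


cfg = {"hidden_size": 4096, "mamba_expand": 2, "mamba_n_groups": 8}
assert find_hparam(cfg, ["hidden_size", "d_model"]) == 4096   # plain key resolves directly
assert find_hparam(cfg, ["n_groups"]) == 8                    # falls back to mamba_n_groups
assert find_hparam(cfg, ["expand"]) * find_hparam(cfg, ["hidden_size"]) == 8192  # d_inner
```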
@@ -6601,11 +6542,53 @@ def modify_tensors(
             or "shared_mlp" in name
         ):
             return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
-        return super().modify_tensors(data_torch, name, bid)
+
+        # Determine whether this is a mamaba layer or an attention layer
+        if bid in self._ssm_layers:
+            return super().modify_tensors(data_torch, name, bid)
+        elif bid in self._attn_layers:
+            return self._transformer_model_class.modify_tensors(self, data_torch, name, bid)
+        return [(self.map_tensor_name(name), data_torch)]
 
     def set_gguf_parameters(self):
         GraniteMoeModel.set_gguf_parameters(self)
-        BambaModel.set_gguf_parameters(self)
+
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        #   in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        if rope_dim := self.hparams.get("attn_rotary_emb"):
+            self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )
+
+        ## If Bamba, use rope, otherwise don't
+        use_rope = "BambaForCausalLM" in self.hparams["architectures"]
+        self.gguf_writer.add_rope_scaling_finetuned(use_rope)
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
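The `modify_tensors` dispatch added in this hunk routes each tensor by block id: MoE/shared-expert tensors go to the Granite MoE handler, SSM blocks to the Mamba2 path, attention blocks to the Granite attention path, and everything else (e.g. embeddings and the output norm) is only renamed. A toy sketch of that routing decision follows; the handler labels and tensor names are stand-ins, not the real class methods or checkpoint keys, and only the `"shared_mlp"` name clause is taken from the visible part of the diff.

```python
# Toy sketch of the per-block routing performed by modify_tensors above.
# The labels are stand-ins for Mamba2Model / GraniteModel /
# GraniteMoeModel.modify_tensors; tensor names below are hypothetical.

def route_tensor(name: str, bid: int | None, attn_layers: list[int], ssm_layers: list[int]) -> str:
    if "shared_mlp" in name:       # MoE shared-expert clause visible in the diff
        return "granite_moe"
    if bid in ssm_layers:
        return "mamba2"            # Mamba2 SSM mixer tensors
    if bid in attn_layers:
        return "attention"         # Granite/Llama-style attention tensors
    return "rename_only"           # e.g. token embeddings, output norm


attn = [5, 11, 17, 23, 29]
ssm = [i for i in range(32) if i not in attn]
assert route_tensor("model.layers.7.mamba.in_proj.weight", 7, attn, ssm) == "mamba2"
assert route_tensor("model.layers.11.self_attn.q_proj.weight", 11, attn, ssm) == "attention"
assert route_tensor("model.embed_tokens.weight", None, attn, ssm) == "rename_only"
```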