[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers #194
Changes from 43 commits
New file (+138 lines): the SSM configuration. It defines the tensor-dimension names shared by the SSM layers and the architecture- and layer-level config classes.

```python
from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig
from fast_llm.layers.common.config import NormalizationArchitectureConfig, NormalizationConfig
from fast_llm.tensor import TensorSpace
from fast_llm.utils import Assert


class SSMDimNames:
    d_model = "model_dimension_D"
    d_state = "state_dimension_N"
    d_conv = "size_of_conv1d_input"  # dimension of the conv1d input in mamba layers
    d_inner = "inner_dimension_after_expansion"
    dt_rank = "rank_of_Δ"
    d_inner_proj_m = "inner_projection_dimension_mamba"
    d_inner_proj_m2 = "inner_projection_dimension_mamba2"
    d_x_proj = "x_projection_dimension"
    headdim = "head_dimension_P"  # dimension of the mamba2 head
    d_conv_kernel = "1d_conv_kernel_size"  # kernel size of the conv1d in mamba layers
    n_qk_heads = "number_of_qk_heads"
    n_v_heads = "number_of_v_heads"
```
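For orientation, a minimal sketch (not part of the PR) of how these names are intended to be used: each constant is a string key under which a `TensorDim` is registered in the `TensorSpace` and later looked up by the layers. A plain dict stands in for the real registration here, since `SSMLayerConfig.setup_tensor_space` below is still a stub, and the relation `d_inner = expansion_factor * hidden_size` is an assumption read off the field names.

```python
from fast_llm.engine.config_utils.tensor_space import TensorDim
from fast_llm.layers.ssm.config import SSMDimNames

hidden_size, expansion_factor, state_size = 4096, 2, 16  # illustrative values
dims = {
    SSMDimNames.d_model: TensorDim(SSMDimNames.d_model, hidden_size),
    SSMDimNames.d_inner: TensorDim(SSMDimNames.d_inner, expansion_factor * hidden_size),  # assumed relation
    SSMDimNames.d_state: TensorDim(SSMDimNames.d_state, state_size),
}
assert dims[SSMDimNames.d_inner].size == 8192  # .size is how DiscreteMamba2 reads a TensorDim
```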
The architecture-level options:

```python
@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
    _abstract = False

    # Normalization
    normalization: NormalizationArchitectureConfig = Field(
        default_factory=NormalizationArchitectureConfig,
        desc="Configuration for the normalization layers architecture.",
        hint=FieldHint.core,
    )

    expansion_factor: int = Field(
        default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.core, valid=check_field(Assert.gt, 0)
    )

    state_size: int = Field(
        default=16,
        desc="State size for Mamba blocks.",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    conv_kernel_dimension: int = Field(
        default=4,
        desc="Conv kernel dimension for Mamba blocks.",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    # Layer parameters
    add_bias_linear: bool = Field(
        default=False,
        desc="Whether to use bias in SSM layers.",
        hint=FieldHint.core,
    )

    dt_rank: str | int = Field(
        default="auto",
        desc="Rank of the Δ projection matrix. If 'auto', set to ceil(hidden_size / 16).",
        hint=FieldHint.core,
    )

    chunk_size: int = Field(
        default=256,
        desc="Chunk size for Mamba2 blocks.",
        hint=FieldHint.core,
    )

    n_qk_heads: int = Field(
        default=32,
        desc="Number of QK heads for Mamba2 blocks.",
        hint=FieldHint.core,
    )

    n_v_heads: int = Field(
        default=32,
        desc="Number of V heads for Mamba2 blocks.",
        hint=FieldHint.core,
    )

    activation: str = Field(
        default="silu",
        desc="Activation function for Mamba2 blocks.",
        hint=FieldHint.core,
    )
```

Reviewer comment on this class: "Please adjust field names for our naming conventions."
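A quick worked example of the `dt_rank="auto"` rule from the field description (the actual resolution happens elsewhere in the model code; this snippet only reproduces the stated formula with an illustrative hidden size):

```python
import math

hidden_size = 4096                     # illustrative d_model
dt_rank = math.ceil(hidden_size / 16)  # "auto" rule: ceil(hidden_size / 16)
print(dt_rank)                         # 256 — the rank of the Δ projection
```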
`SSMLayerConfig` extends the architecture config with runtime, debug, and discretization options, plus validation:

```python
@config_class()
class SSMLayerConfig(SSMArchitectureConfig):
    """Configuration for a Structured State Space Model (SSM) layer."""

    normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)

    # Performance optimization
    use_fast_path: bool = Field(
        default=True,
        desc="Whether to use optimized CUDA kernels when available",
        hint=FieldHint.performance,
    )

    debug_ssm: bool = Field(
        default=False,
        desc="debug_ssm",
        hint=FieldHint.optional,
    )

    dt_min: float = Field(
        default=0.001,
        desc="Minimum step size for discretization",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    dt_max: float = Field(
        default=0.1,
        desc="Maximum step size for discretization",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    dt_init_floor: float = Field(
        default=1e-4,
        desc="Minimum value for initializing dt",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
        pass

    def _validate(self) -> None:
        """Validate configuration parameters."""
        super()._validate()
        Assert.geq(self.dt_max, self.dt_min)
        if isinstance(self.dt_rank, int):
            Assert.gt(self.dt_rank, 0)
```
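The constraints enforced by `_validate`, restated as a standalone sketch with plain assertions rather than the `Assert` helpers (illustrative only; it assumes nothing about the `config_class` machinery):

```python
def check_ssm_config(dt_min: float, dt_max: float, dt_rank) -> None:
    """Mirror of SSMLayerConfig._validate, for illustration only."""
    assert dt_min > 0 and dt_max > 0  # field-level check_field(Assert.gt, 0)
    assert dt_max >= dt_min           # Assert.geq(self.dt_max, self.dt_min)
    if isinstance(dt_rank, int):
        assert dt_rank > 0            # an explicit rank must be positive
    # dt_rank == "auto" is resolved later to ceil(hidden_size / 16)

check_ssm_config(dt_min=0.001, dt_max=0.1, dt_rank="auto")  # the defaults pass
```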
New file (+234 lines): the `DiscreteMamba2` mixer. It wraps the Triton chunk-scan kernel from `mamba_ssm` and the optional `causal_conv1d` kernels behind Fast-LLM's tensor-space and parameter-meta machinery.

```python
import math

import torch
from einops import rearrange
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
from fast_llm.layers.common.linear import Linear
from fast_llm.layers.ssm.config import SSMDimNames, SSMLayerConfig
from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

"""
This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py
"""


def bias_init_method(conv_weight):
    fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    return init_uniform_(-bound, bound)
```
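`bias_init_method` reproduces PyTorch's default `Conv1d` bias initialization (uniform in ±1/√fan_in) for the hand-rolled conv weight. A quick check of the bound it yields for the weight layout used below; the snippet is illustrative, not from the PR:

```python
import math
import torch

# Depthwise conv weight laid out as (channels, 1, kernel_size), as in DiscreteMamba2.
conv_weight = torch.empty(1024, 1, 4)
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight)
print(fan_in, 1 / math.sqrt(fan_in))  # 4 0.5 -> bias initialized from U(-0.5, 0.5)
```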
The mixer's constructor resolves its dimensions from the `TensorSpace` and declares the projections, the depthwise conv, and the skip parameter as `ParameterMeta`s:

```python
class DiscreteMamba2(torch.nn.Module):
    """DiscreteMamba2 (taken from github.com/goombalab/phi-mamba.git)."""

    def __init__(
        self,
        config: SSMLayerConfig,
        layer_idx: int,
        tensor_space: TensorSpace,
    ):
        """
        See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args.
        TODO: check what this comment means
        Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr".

        Other options are all experimental and should not need to be configured.
        """
        # factory_kwargs = {"device": "meta"}  # , "dtype": torch.bfloat16}
        super().__init__()
        self.config: SSMLayerConfig = config
        bias = config.add_bias_linear
        self.layer_idx = layer_idx

        td_inner = tensor_space.get_tensor_dim(SSMDimNames.d_inner)
        td_state = tensor_space.get_tensor_dim(SSMDimNames.d_state)
        td_model = tensor_space.get_tensor_dim(SSMDimNames.d_model)
        td_conv = tensor_space.get_tensor_dim(SSMDimNames.d_conv)
        td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.n_qk_heads)
        td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.n_v_heads)
        td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.d_conv_kernel)
        td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.d_inner_proj_m2)

        self.d_model = td_model.size
        self.d_inner = td_inner.size
        self.d_state = td_state.size
        self.chunk_size = config.chunk_size
        self.n_qk_heads = td_n_qk_heads.size
        self.n_v_heads = td_n_v_heads.size
        self.conv_kernel_size = td_conv_kernel.size

        self.activation = config.activation
        if self.activation == "silu":
            self.act = torch.nn.SiLU()
        elif self.activation == "identity":
            self.act = torch.nn.Identity()
        else:
            raise ValueError(f"Activation {self.activation} not supported")

        # TODO: double check initializations
        # Projections
        self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size))
        self.z_bias = (
            ParameterMeta.from_dims(
                (td_inner,),
                weight_decay=False,
                init_method=init_zeros_,
            )
            if not bias
            else 0.0
        )

        # Convolutional layer
        self.conv1d_weight = ParameterMeta.from_dims(
            (td_conv, TensorDim("1", 1), td_conv_kernel),
            init_method=init_uniform_(
                1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size)
            ),  # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67
        )
        self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight))

        # D "skip" parameter
        self.D = ParameterMeta.from_dims(
            (td_n_qk_heads,),
            weight_decay=False,
            init_method=init_ones_,
        )

        # out_proj
        self.out_proj = Linear(
            td_inner,
            td_model,
            bias=bias,
            weight_init_method=kaiming_init_(td_inner.size),
        )

    @property
    def d_output(self):
        """Returns the output dimension of the model."""
        return self.d_model

    @property
    def state_to_tensor(self):
        """Returns the state of the model as a tensor."""
        return self.layer.state_to_tensor
```
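The width of `in_proj`'s output (`d_inner_proj_m2`) is fixed by the split performed in `forward` below: `d_inner + 2 * n_qk_heads * d_state` for `xBC`, plus `d_inner` for the gate `z`, plus `n_v_heads` for `A_log`. A worked example using the config defaults and an illustrative `d_model` of 4096:

```python
d_model, expansion_factor = 4096, 2          # d_model is illustrative
d_inner = expansion_factor * d_model         # assumed relation between d_inner and expansion_factor
d_state, n_qk_heads, n_v_heads = 16, 32, 32  # config defaults

d_inner_proj_m2 = (d_inner + 2 * n_qk_heads * d_state) + d_inner + n_v_heads
print(d_inner_proj_m2)  # 17440 = 9216 (xBC) + 8192 (z) + 32 (A_log)
```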
The forward pass: project, split, causal conv, chunked SSM scan, skip and gate, then project back out:

```python
    def forward(self, u, **kwargs):
        """
        Args:
            u: (B, L, D),

        Returns:
            outputs: dict.
                outputs["hidden_states"]: (B, L, D).
                outputs["state"]: inference cache.
        """
        outputs = {}
        # assert state is None
        batch, seqlen, dim = u.shape

        state = None

        # Hacky way to initialize state during inference
        chunk_size = self.chunk_size if state is None else seqlen

        # Pad input to nearest multiple of chunklen
        padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size
        u = torch.nn.functional.pad(u, (0, 0, 0, padded_len - seqlen))

        # Project input
        xBCzA_log = self.in_proj(u)

        xBC, z, A_log = torch.split(
            xBCzA_log,
            [
                self.d_inner + 2 * self.n_qk_heads * self.d_state,
                self.d_inner,
                self.n_v_heads,
            ],
            dim=-1,
        )

        if state is not None:
            # If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
            # Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l")
            state["conv"].copy_(
                torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0))
            )  # Update state (B D W)

        # Convolutional layer
        xBC = self.convolutional_forward(xBC, padded_len)

        x, B, C = torch.split(
            xBC,
            [
                self.d_inner,
                self.n_qk_heads * self.d_state,
                self.n_qk_heads * self.d_state,
            ],
            dim=-1,
        )

        x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads)
        B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads)
        C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads)

        # SSM forward
        result = mamba_chunk_scan_combined(
            x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1),
            dt=A_log,
            dt_softplus=True,
            A=-torch.ones(self.n_v_heads, device=A_log.device),
            B=B,
            C=C,
            chunk_size=chunk_size,
            # initial_states=(state["ssm"] if state is not None else None),  # currently not supported by mamba_ssm.utils.generation
            return_final_states=(state is not None),
        )

        if state is not None:
            y, ssm_state = result
            state["ssm"].copy_(ssm_state)
        else:
            y = result

        Du = torch.einsum("h,blhp->blhp", self.D, x)
        y = rearrange(y + Du, "b l h p -> b l (h p)")

        # Norm and gate
        out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias))
        outputs["hidden_states"] = out[:, :seqlen, :]

        # TODO: since we do not support inference yet, we only return the hidden states.
        return outputs["hidden_states"].contiguous()
```
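A shape-bookkeeping sketch of the chunk padding done at the top of `forward`, with illustrative values (running the real forward also needs the Triton kernels and a populated `TensorSpace`):

```python
def padded_length(seqlen: int, chunk_size: int) -> int:
    # Same rounding as in forward: pad L up to the next multiple of chunk_size.
    return (1 + (seqlen - 1) // chunk_size) * chunk_size

print(padded_length(1000, 256))  # 1024: 24 padding positions are added, then stripped from the output
print(padded_length(1024, 256))  # 1024: already a multiple, no padding
```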
The convolution currently requires the fused `causal_conv1d` kernel (a pure-PyTorch path is left commented out):

```python
    def convolutional_forward(self, xBC, padded_len):
        """Convolutional layer forward pass for the full sequence."""
        if causal_conv1d_fn is None or self.activation not in [
            "silu",
            "swish",
            "identity",
        ]:
            raise NotImplementedError("Only support causal_conv1d_fn kernel for now")
            # xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2))
        else:
            xBC = causal_conv1d_fn(
                xBC.transpose(1, 2),
                rearrange(self.conv1d_weight, "d 1 w -> d w"),
                self.conv1d_bias,
                activation=None if self.activation == "identity" else self.activation,
            ).transpose(1, 2)
        return xBC
```
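For reference, the fallback that the commented-out line hints at can be written with plain PyTorch as a depthwise causal convolution. This sketch reflects an assumption about what `causal_conv1d_fn` computes (a left-padded depthwise conv followed by the activation); it is not code from the PR:

```python
import torch
import torch.nn.functional as F

def causal_conv1d_reference(x, weight, bias, activation="silu"):
    """x: (B, D, L); weight: (D, W); bias: (D,). Depthwise causal conv, assumed equivalent of causal_conv1d_fn."""
    d, w = weight.shape
    out = F.conv1d(F.pad(x, (w - 1, 0)), weight.unsqueeze(1), bias, groups=d)  # left-pad keeps causality
    return F.silu(out) if activation in ("silu", "swish") else out

x = torch.randn(2, 8, 16)  # (batch, channels, length)
y = causal_conv1d_reference(x, torch.randn(8, 4), torch.zeros(8))
print(y.shape)             # torch.Size([2, 8, 16])
```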