[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers #194
Merged

Changes shown are from 8 of 53 commits. Commit history (all by oleksost):
ad2b8d5  wip
5dbc72a  wip
5213e9e  WIP
01d0fe4  mamba1 block
ae8f3ca  removed build from remote
fa6d3bc  removed unneccesary tests
963e674  removed unneccesary files
0842fcb  test
1c81719  test mamba1
37ba0d5  tensor dimentions
11a5db3  meta init with full model run
4af7eb7  training, but having backward issues
be93749  integration into training pipeline
dd469bc  mamba2
ebe1b75  renamed config + skip test
a4400fd  skip tests if mamba not installed
c49148c  pre-commits
5c8d930  cleanup
ef6791b  dependencies
f03dd10  descrete mamba2
2414252  Merge branch 'ssm_mamba2' into ssm
f4d411d  test
ee86c68  llamba checkpoint converter
2561738  cleanup
ad8a48c  test
5243a88  Merge branch 'main' into ssm
075a31f  mamba force build
a788989  mamba force build
2700660  mamba force build
baaf714  causal conv skip build
833b586  Merge branch 'main' into ssm
9e2897d  docs.yaml
b231cb8  MTP hardcoded
8ccaa28  import nvm
864fff2  remove dependency on cartesia
7f2b35f  save llamba
81c71af  addressed comments
7b7ce62  addressed comments
776e67b  Merge branch 'main' into ssm
3456884  nvm
b48f68d  renamed block pattern into block layout
9a35783  renames
32b8aa1  nvm
4f9aad0  wip
68de5d1  addressed comments
ebc516a  Merge branch 'main' into ssm
cb95e52  wip
79c9a4b  batch config
bb3ba66  clean up
a5297be  nvm
2d39857  tests
df032b5  nvm
c8fdbb9  identity activation into MLP
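The PR builds a hybrid model in which transformer, Mamba-1, and discrete Mamba-2 blocks are interleaved according to a per-layer block layout (see the "renamed block pattern into block layout" commit). A minimal sketch of the idea follows; the layout tags, class names, and dispatch are illustrative assumptions, not the PR's actual API.

```python
# Hypothetical sketch of a hybrid block layout; tags and names are illustrative only.
block_layout = ["t", "m", "t", "m2"]  # transformer, Mamba-1, transformer, discrete Mamba-2


def make_block(kind: str) -> str:
    # The real code would construct layer modules from their configs;
    # here we only dispatch on the layout tag.
    return {"t": "TransformerBlock", "m": "MambaBlock", "m2": "DiscreteMamba2Block"}[kind]


layers = [make_block(kind) for kind in block_layout]
print(layers)  # ['TransformerBlock', 'MambaBlock', 'TransformerBlock', 'DiscreteMamba2Block']
```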
oleksost File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
New file (+147 lines): Conv1DBase and Conv1D layer definitions.

```python
import logging
import typing

import torch

from fast_llm.engine.config_utils.tensor_space import TensorDim
from fast_llm.tensor import ParameterMeta, init_zeros_

logger = logging.getLogger(__name__)


class Conv1DBase(torch.nn.Module):
    """
    A base module for 1D convolutional layers holding weights and biases.
    """

    def __init__(
        self,
        in_channels: TensorDim,
        out_channels: TensorDim,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        *,
        bias=True,
        weight_init_method,
        bias_init_method=init_zeros_,
        auto_bias_grad_accumulation: bool = False,
        lr_scale: float | None | tuple[float | None, ...] = None,
    ):
        super().__init__()
        self._in_channels = in_channels
        self._out_channels = out_channels
        self._kernel_size = kernel_size
        self._stride = stride
        self._padding = padding
        self._dilation = dilation
        self._groups = groups

        self.weight = ParameterMeta.from_dims(
            (
                self._out_channels,
                TensorDim("D_in", self._in_channels.size // groups),
                TensorDim("D_kernel", self._kernel_size),
            ),
            init_method=weight_init_method,
            auto_grad_accumulation=False,
            lr_scale=lr_scale,
        )

        if bias:
            self.bias = ParameterMeta.from_dims(
                (self._out_channels,),
                init_method=bias_init_method,
                weight_decay=False,
                auto_grad_accumulation=auto_bias_grad_accumulation,
                lr_scale=lr_scale,
            )
        else:
            self.bias = None


class Conv1D(Conv1DBase):
    """
    A basic 1D convolutional layer without tensor parallelism.
    """

    def __init__(
        self,
        in_channels: TensorDim,
        out_channels: TensorDim,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        *,
        bias=True,
        weight_init_method,
        bias_init_method=init_zeros_,
        lr_scale: float | None | tuple[float | None, ...] = None,
    ):
        assert in_channels.parallel_dim is None
        assert out_channels.parallel_dim is None
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias=bias,
            weight_init_method=weight_init_method,
            bias_init_method=bias_init_method,
            lr_scale=lr_scale,
        )

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.conv1d(
            input_,
            self.weight,
            self.bias,
            stride=self._stride,
            padding=self._padding,
            dilation=self._dilation,
            groups=self._groups,
        )

    def forward_only(
        self, input_: torch.Tensor
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, dict]]:
        # Store context for backward pass
        context = {
            "input": input_,
            "weight": self.weight,
            "stride": self._stride,
            "padding": self._padding,
            "dilation": self._dilation,
            "groups": self._groups,
        }

        output = torch.nn.functional.conv1d(
            input_,
            self.weight,
            self.bias,
            stride=self._stride,
            padding=self._padding,
            dilation=self._dilation,
            groups=self._groups,
        )

        return output, (input_, self.weight, context)

    def backward(self, grad_output: torch.Tensor, context: tuple[torch.Tensor, torch.Tensor, dict]) -> torch.Tensor:
        input_, weight, ctx = context

        # Calculate gradients (input gradient only)
        grad_input = torch.nn.grad.conv1d_input(
            input_.shape,
            weight,
            grad_output,
            stride=ctx["stride"],
            padding=ctx["padding"],
            dilation=ctx["dilation"],
            groups=ctx["groups"],
        )

        return grad_input
```
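Conv1D.backward returns only the input gradient, computed with torch.nn.grad.conv1d_input. Below is a self-contained sanity check (not part of the PR) showing that this call reproduces autograd's input gradient for a plain conv1d; shapes and hyperparameters are arbitrary illustrative choices.

```python
import torch

# Random input and kernel: (batch, channels, length) and (out_ch, in_ch, kernel).
x = torch.randn(2, 8, 16, requires_grad=True)
w = torch.randn(8, 8, 4)

y = torch.nn.functional.conv1d(x, w, stride=1, padding=3)
grad_out = torch.randn_like(y)

# Reference: input gradient from autograd.
(grad_autograd,) = torch.autograd.grad(y, x, grad_out)

# Manual: the same gradient via torch.nn.grad.conv1d_input, as used in Conv1D.backward.
grad_manual = torch.nn.grad.conv1d_input(x.shape, w, grad_out, stride=1, padding=3)

assert torch.allclose(grad_autograd, grad_manual, atol=1e-5)
```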
New file (+175 lines): MambaConfig, the configuration for the SSM layers.

```python
import math
from typing import Optional

from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class, skip_valid_if_none
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.layers.common.config import NormalizationConfig
from fast_llm.layers.transformer.config import TransformerArchitectureConfig
from fast_llm.utils import Assert


@config_class()
class MambaConfig(TransformerArchitectureConfig, BaseModelConfig):
    """Configuration for a Structured State Space Model (SSM) layer."""

    # Core architecture parameters
    hidden_size: int = Field(
        default=768,
        desc="Size of the hidden representations",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )
    state_size: int = Field(
        default=64,
        desc="Size of the internal state vector",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )
    expansion_factor: int = Field(
        default=2,
        desc="Factor by which to expand hidden size in SSM computation",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    # SSM-specific parameters
    conv_dimension: int = Field(
        default=4,
        desc="Size of the convolutional kernel",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )
    dt_rank: str | int = Field(
        default="auto",
        desc="Rank of the Δ projection matrix. If 'auto', set to ceil(hidden_size/16)",
        hint=FieldHint.core,
    )
    dt_min: float = Field(
        default=0.001,
        desc="Minimum step size for discretization",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )
    dt_max: float = Field(
        default=0.1,
        desc="Maximum step size for discretization",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )
    dt_init_floor: float = Field(
        default=1e-4,
        desc="Minimum value for initializing dt",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    # Layer parameters
    add_bias_linear: bool = Field(
        default=False,
        desc="Whether to use bias in linear transformations",
        hint=FieldHint.core,
    )
    conv_bias: bool = Field(
        default=True,
        desc="Whether to use bias in convolution layer",
        hint=FieldHint.core,
    )

    # Normalization
    normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)

    # Performance optimization
    use_fast_path: bool = Field(
        default=True,
        desc="Whether to use optimized CUDA kernels when available",
        hint=FieldHint.performance,
    )

    # Initialization parameters
    init_method_std: float = Field(
        default=None,
        desc="Default scale for weight initialization. Default: hidden_size**-0.5",
        hint=FieldHint.optional,
        valid=skip_valid_if_none(check_field(Assert.geq, 0)),
    )

    device: str = Field(
        default="cuda",
        desc="device",
        hint=FieldHint.optional,
    )
    mamba_headdim: int = Field(
        default=64,
        desc="headdim",
        hint=FieldHint.optional,
    )
    mamba_ngroups: int = Field(
        default=1,
        desc="ngroups",
        hint=FieldHint.optional,
    )
    use_low_rank_mamba_proj: bool = Field(
        default=False,
        desc="use_low_rank_mamba_proj",
        hint=FieldHint.optional,
    )
    use_module_layernorm: bool = Field(
        default=False,
        desc="use_module_layernorm",
        hint=FieldHint.optional,
    )
    layernorm_epsilon: float = Field(
        default=1e-5,
        desc="layernorm_epsilon",
        hint=FieldHint.optional,
    )
    rms_norm: bool = Field(
        default=False,
        desc="rms_norm",
        hint=FieldHint.optional,
    )
    fused_add_norm: bool = Field(
        default=False,
        desc="fused_add_norm",
        hint=FieldHint.optional,
    )
    residual_in_fp32: bool = Field(
        default=False,
        desc="residual_in_fp32",
        hint=FieldHint.optional,
    )

    def _validate(self) -> None:
        """Validate configuration parameters."""
        if self.init_method_std is None:
            self.init_method_std = self.hidden_size**-0.5

        super()._validate()

        # Validate SSM-specific parameters
        Assert.gt(self.state_size, 0)
        Assert.gt(self.expansion_factor, 0)
        Assert.gt(self.conv_dimension, 0)
        Assert.gt(self.dt_min, 0)
        Assert.gt(self.dt_max, 0)
        Assert.gt(self.dt_init_floor, 0)
        Assert.geq(self.dt_max, self.dt_min)

        if isinstance(self.dt_rank, int):
            Assert.gt(self.dt_rank, 0)
```
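For reference, a small illustration of how the derived SSM sizes follow from these defaults. The ceil(hidden_size / 16) rule for dt_rank="auto" is stated in the field description above; the inner width d_inner = expansion_factor * hidden_size is the usual Mamba convention and is an assumption here, not something read from this diff.

```python
import math

# Defaults taken from MambaConfig above.
hidden_size = 768
expansion_factor = 2
state_size = 64

d_inner = expansion_factor * hidden_size  # width of the expanded SSM branch (assumed convention)
dt_rank = math.ceil(hidden_size / 16)     # the documented 'auto' rule for dt_rank

print(d_inner, dt_rank)  # 1536 48
```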
Review comment: We'll need a proper architecture/non-architecture split for things to work properly.
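Illustrative only: one way to read this request is to group the MambaConfig fields above into architecture fields (those that define the model's structure) and non-architecture fields (runtime and initialization knobs). The grouping below is a hypothetical sketch, not the PR's final design.

```python
# Hypothetical grouping of the fields shown in this diff; the split is an assumption.
ARCHITECTURE_FIELDS = {
    "hidden_size", "state_size", "expansion_factor", "conv_dimension", "dt_rank",
    "add_bias_linear", "conv_bias", "mamba_headdim", "mamba_ngroups",
    "use_low_rank_mamba_proj", "rms_norm",
}
NON_ARCHITECTURE_FIELDS = {
    "device", "use_fast_path", "init_method_std", "dt_min", "dt_max",
    "dt_init_floor", "fused_add_norm", "residual_in_fp32", "layernorm_epsilon",
}
```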