Merged
Changes from 43 commits
Commits
53 commits
ad2b8d5
wip
oleksost Feb 18, 2025
5dbc72a
wip
oleksost Mar 13, 2025
5213e9e
WIP
oleksost Mar 14, 2025
01d0fe4
mamba1 block
oleksost Mar 20, 2025
ae8f3ca
removed build from remote
oleksost Mar 20, 2025
fa6d3bc
removed unneccesary tests
oleksost Mar 20, 2025
963e674
removed unneccesary files
oleksost Mar 20, 2025
0842fcb
test
oleksost Mar 20, 2025
1c81719
test mamba1
oleksost Mar 24, 2025
37ba0d5
tensor dimentions
oleksost Mar 24, 2025
11a5db3
meta init with full model run
oleksost Mar 24, 2025
4af7eb7
training, but having backward issues
oleksost Mar 25, 2025
be93749
integration into training pipeline
oleksost Mar 30, 2025
dd469bc
mamba2
oleksost Mar 31, 2025
ebe1b75
renamed config + skip test
oleksost Mar 31, 2025
a4400fd
skip tests if mamba not installed
oleksost Mar 31, 2025
c49148c
pre-commits
oleksost Mar 31, 2025
5c8d930
cleanup
oleksost Mar 31, 2025
ef6791b
dependencies
oleksost Mar 31, 2025
f03dd10
descrete mamba2
oleksost Mar 31, 2025
2414252
Merge branch 'ssm_mamba2' into ssm
oleksost Mar 31, 2025
f4d411d
test
oleksost Mar 31, 2025
ee86c68
llamba checkpoint converter
oleksost Apr 3, 2025
2561738
cleanup
oleksost Apr 4, 2025
ad8a48c
test
oleksost Apr 4, 2025
5243a88
Merge branch 'main' into ssm
oleksost Apr 4, 2025
075a31f
mamba force build
oleksost Apr 7, 2025
a788989
mamba force build
oleksost Apr 7, 2025
2700660
mamba force build
oleksost Apr 7, 2025
baaf714
causal conv skip build
oleksost Apr 7, 2025
833b586
Merge branch 'main' into ssm
oleksost Apr 7, 2025
9e2897d
docs.yaml
oleksost Apr 7, 2025
b231cb8
MTP hardcoded
oleksost Apr 7, 2025
8ccaa28
import nvm
oleksost Apr 7, 2025
864fff2
remove dependency on cartesia
oleksost Apr 7, 2025
7f2b35f
save llamba
oleksost Apr 7, 2025
81c71af
addressed comments
oleksost Apr 8, 2025
7b7ce62
addressed comments
oleksost Apr 9, 2025
776e67b
Merge branch 'main' into ssm
oleksost Apr 9, 2025
3456884
nvm
oleksost Apr 10, 2025
b48f68d
renamed block pattern into block layout
oleksost Apr 11, 2025
9a35783
renames
oleksost Apr 11, 2025
32b8aa1
nvm
oleksost Apr 14, 2025
4f9aad0
wip
oleksost Apr 16, 2025
68de5d1
addressed comments
oleksost Apr 23, 2025
ebc516a
Merge branch 'main' into ssm
oleksost Apr 23, 2025
cb95e52
wip
oleksost Apr 23, 2025
79c9a4b
batch config
oleksost Apr 23, 2025
bb3ba66
clean up
oleksost Apr 23, 2025
a5297be
nvm
oleksost Apr 23, 2025
2d39857
tests
oleksost Apr 23, 2025
df032b5
nvm
oleksost Apr 23, 2025
c8fdbb9
identity activation into MLP
oleksost Apr 23, 2025
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
@@ -29,7 +29,7 @@ jobs:
run: |
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"

- name: Run tests
run: pytest .
2 changes: 1 addition & 1 deletion .github/workflows/docs.yaml
@@ -31,7 +31,7 @@ jobs:
- run: |
pip install "torch>=2.2.2"
pip install pybind11
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
- name: Build the documentation
run: mkdocs build

138 changes: 138 additions & 0 deletions fast_llm/layers/ssm/config.py
@@ -0,0 +1,138 @@
from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig
from fast_llm.layers.common.config import NormalizationArchitectureConfig, NormalizationConfig
from fast_llm.tensor import TensorSpace
from fast_llm.utils import Assert


class SSMDimNames:
d_model = "model_dimension_D"
d_state = "state_dimension_N"
d_conv = "size_of_conv1d_input"  # dimension of the conv1d input in mamba layers
d_inner = "inner_dimension_after_expansion"
dt_rank = "rank_of_Δ"
d_inner_proj_m = "inner_projection_dimension_mamba"
d_inner_proj_m2 = "inner_projection_dimension_mamba2"
d_x_proj = "x_projection_dimension"
headdim = "head_dimension_P"  # dimension of the mamba2 head
d_conv_kernel = "1d_conv_kernel_size" # kernel size of the conv1d in mamba layers
n_qk_heads = "number_of_qk_heads"
n_v_heads = "number_of_v_heads"


@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
Collaborator: Please adjust field names for our naming conventions.

_abstract = False

# Normalization
normalization: NormalizationArchitectureConfig = Field(
default_factory=NormalizationArchitectureConfig,
desc="Configuration for the normalization layers architecture.",
hint=FieldHint.core,
)

expansion_factor: int = Field(
default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.core, valid=check_field(Assert.gt, 0)
)

state_size: int = Field(
default=16,
desc="State size for Mamba blocks.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
conv_kernel_dimension: int = Field(
default=4,
desc="Conv kernel dimension for Mamba blocks.",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

# Layer parameters
add_bias_linear: bool = Field(
default=False,
desc="Whether to use bias in SSM layers",
hint=FieldHint.core,
)

dt_rank: str | int = Field(
Collaborator: Please use None for derived defaults. dt_rank: int = Field(default=None, ...

default="auto",
desc="Rank of the Δ projection matrix. If 'auto', set to ceil(hidden_size/16)",
hint=FieldHint.core,
)

chunk_size: int = Field(
default=256,
desc="Chunk size for Mamba2 blocks.",
hint=FieldHint.core,
)

n_qk_heads: int = Field(
default=32,
desc="Number of QK heads for Mamba2 blocks.",
hint=FieldHint.core,
)

n_v_heads: int = Field(
default=32,
desc="Number of V heads for Mamba2 blocks.",
hint=FieldHint.core,
)

activation: str = Field(
default="silu",
desc="Activation function for Mamba2 blocks.",
hint=FieldHint.core,
)


@config_class()
class SSMLayerConfig(SSMArchitectureConfig):
"""Configuration for a Structured State Space Model (SSM) layer."""

normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)
# Performance optimization
use_fast_path: bool = Field(
default=True,
desc="Whether to use optimized CUDA kernels when available",
hint=FieldHint.performance,
)

debug_ssm: bool = Field(
default=False,
desc="debug_ssm",
hint=FieldHint.optional,
)

dt_min: float = Field(
default=0.001,
desc="Minimum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

dt_max: float = Field(
default=0.1,
desc="Maximum step size for discretization",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

dt_init_floor: float = Field(
default=1e-4,
desc="Minimum value for initializing dt",
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)

def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
pass

def _validate(self) -> None:
"""Validate configuration parameters."""

super()._validate()
Assert.geq(self.dt_max, self.dt_min)

if isinstance(self.dt_rank, int):
Assert.gt(self.dt_rank, 0)
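
For reference, a minimal sketch of how the derived defaults above could resolve. The `ceil(hidden_size / 16)` rule comes from the `dt_rank` field description; the hidden size and the `expansion_factor * hidden_size` rule for the inner dimension are assumptions for illustration only (the actual derivation happens in the tensor-space setup, which is not part of this file).

```python
import math

# Hypothetical hidden size, for illustration only.
hidden_size = 4096

# Defaults from SSMArchitectureConfig above.
expansion_factor = 2
state_size = 16

# dt_rank == "auto" resolves to ceil(hidden_size / 16), per the field description.
dt_rank = math.ceil(hidden_size / 16)      # 256

# Assumed Mamba convention: inner dimension = expansion_factor * hidden_size.
d_inner = expansion_factor * hidden_size   # 8192

print(dt_rank, d_inner)
```
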
234 changes: 234 additions & 0 deletions fast_llm/layers/ssm/discrete_mamba2.py
@@ -0,0 +1,234 @@
import math

import torch
from einops import rearrange
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
from fast_llm.layers.common.linear import Linear
from fast_llm.layers.ssm.config import SSMDimNames, SSMLayerConfig
from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_

try:
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
causal_conv1d_fn, causal_conv1d_update = None, None


try:
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
selective_state_update = None

"""
This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py
"""


def bias_init_method(conv_weight):
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(conv_weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
return init_uniform_(-bound, bound)


class DiscreteMamba2(torch.nn.Module):
"""DiscreteMamba2 (taken github.com/goombalab/phi-mamba.git)."""

def __init__(
self,
config: SSMLayerConfig,
layer_idx: int,
tensor_space: TensorSpace,
):
"""
See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args.
TODO: check what this comment means
Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr".

Other options are all experimental and should not need to be configured.
"""
# factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16}
super().__init__()
self.config: SSMLayerConfig = config
bias = config.add_bias_linear
self.layer_idx = layer_idx

td_inner = tensor_space.get_tensor_dim(SSMDimNames.d_inner)
td_state = tensor_space.get_tensor_dim(SSMDimNames.d_state)
td_model = tensor_space.get_tensor_dim(SSMDimNames.d_model)
td_conv = tensor_space.get_tensor_dim(SSMDimNames.d_conv)
td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.n_qk_heads)
td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.n_v_heads)
td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.d_conv_kernel)
td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.d_inner_proj_m2)

self.d_model = td_model.size
self.d_inner = td_inner.size
self.d_state = td_state.size
self.chunk_size = config.chunk_size
self.n_qk_heads = td_n_qk_heads.size
self.n_v_heads = td_n_v_heads.size
self.conv_kernel_size = td_conv_kernel.size

self.activation = config.activation
if self.activation == "silu":
self.act = torch.nn.SiLU()
elif self.activation == "identity":
self.act = torch.nn.Identity()
else:
raise ValueError(f"Activation {self.activation} not supported")

# TODO: double-check initializations
# Projections
self.in_proj = Linear(td_model, td_inner_proj, bias=bias, weight_init_method=kaiming_init_(td_model.size))
self.z_bias = (
ParameterMeta.from_dims(
(td_inner,),
weight_decay=False,
init_method=init_zeros_,
)
if not bias
else 0.0
)

# Convolutional layer
self.conv1d_weight = ParameterMeta.from_dims(
(td_conv, TensorDim("1", 1), td_conv_kernel),
init_method=init_uniform_(
1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size)
), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67
)
self.conv1d_bias = ParameterMeta.from_dims((td_conv,), init_method=bias_init_method(self.conv1d_weight))

# D "skip" parameter
self.D = ParameterMeta.from_dims(
(td_n_qk_heads,),
weight_decay=False,
init_method=init_ones_,
)

# out_proj
self.out_proj = Linear(
td_inner,
td_model,
bias=bias,
weight_init_method=kaiming_init_(td_inner.size),
)

@property
def d_output(self):
"""Returns the output dimension of the model."""
return self.d_model

@property
def state_to_tensor(self):
"""Returns the state of the model as a tensor."""
return self.layer.state_to_tensor

def forward(self, u, **kwargs):
Collaborator: Names in this method are a bit cryptic (also PEP 8).

Contributor Author: Agreed; yet again, for clarity, I would rather keep the variable names as in the original code for the Mamba layers.

"""
Args:
u: (B, L, D),

Returns:
outputs: dict.
outputs["hidden_states"]: (B, L, D).
outputs["state"]: inference cache.
"""
outputs = {}
# assert state is None
batch, seqlen, dim = u.shape

state = None

# Hacky way to initialize state during inference
chunk_size = self.chunk_size if state is None else seqlen

# Pad input to the nearest multiple of the chunk size
padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size
u = torch.nn.functional.pad(u, (0, 0, 0, padded_len - seqlen))

# Project input
xBCzA_log = self.in_proj(u)

xBC, z, A_log = torch.split(
xBCzA_log,
[
self.d_inner + 2 * self.n_qk_heads * self.d_state,
self.d_inner,
self.n_v_heads,
],
dim=-1,
)

if state is not None:
# If we just take xBC[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
# Instead torch.nn.functional.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
xBC_t = rearrange(xBC[:, :seqlen, :], "b l d -> b d l")
state["conv"].copy_(
torch.nn.functional.pad(xBC_t, (self.conv_kernel_size - xBC_t.shape[-1], 0))
) # Update state (B D W)

# Convolutional layer
xBC = self.convolutional_forward(xBC, padded_len)

x, B, C = torch.split(
xBC,
[
self.d_inner,
self.n_qk_heads * self.d_state,
self.n_qk_heads * self.d_state,
],
dim=-1,
)

x = rearrange(x, "b l (h n) -> b l h n", h=self.n_v_heads)
B = rearrange(B, "b l (h n) -> b l h n", h=self.n_qk_heads)
C = rearrange(C, "b l (h n) -> b l h n", h=self.n_qk_heads)

# SSM forward
result = mamba_chunk_scan_combined(
x=x / torch.nn.functional.softplus(A_log).to(x.dtype).unsqueeze(-1),
dt=A_log,
dt_softplus=True,
A=-torch.ones(self.n_v_heads, device=A_log.device),
B=B,
C=C,
chunk_size=chunk_size,
# initial_states=(state["ssm"] if state is not None else None), # currently not supported by mamba_ssm.utils.generation
return_final_states=(state is not None),
)

if state is not None:
y, ssm_state = result
state["ssm"].copy_(ssm_state)
else:
y = result

Du = torch.einsum("h,blhp->blhp", self.D, x)
y = rearrange(y + Du, "b l h p -> b l (h p)")

# Norm and gate
out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias))
outputs["hidden_states"] = out[:, :seqlen, :]

# TODO: since we do not support inference for now, we only return the hidden states for now.
return outputs["hidden_states"].contiguous()

def convolutional_forward(self, xBC, padded_len):
"""Convolutional layer forward pass for the full sequence."""
if causal_conv1d_fn is None or self.activation not in [
"silu",
"swish",
"identity",
]:
raise NotImplementedError("Only support causal_conv1d_fn kernel for now")
# xBC = self.act(self.conv1d(xBC.transpose(1, 2))[..., :padded_len].transpose(1, 2))
else:
xBC = causal_conv1d_fn(
xBC.transpose(1, 2),
rearrange(self.conv1d_weight, "d 1 w -> d w"),
self.conv1d_bias,
activation=None if self.activation == "identity" else self.activation,
).transpose(1, 2)
return xBC
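
As a sanity check on the tensor shapes in `forward()`, here is a small, self-contained sketch of the dimension arithmetic implied by the two `torch.split` calls and the chunk-size padding. The concrete numbers are hypothetical; only the formulas are taken from the code above.

```python
# Hypothetical sizes, for illustration only; the formulas mirror forward() above.
d_inner = 8192
d_state = 16
n_qk_heads = 32
n_v_heads = 32
chunk_size = 256
seqlen = 1000

# Width of the in_proj output (split into xBC, z, A_log):
#   [d_inner + 2 * n_qk_heads * d_state, d_inner, n_v_heads]
d_xBC = d_inner + 2 * n_qk_heads * d_state
d_inner_proj = d_xBC + d_inner + n_v_heads     # 17440

# xBC is split again into x, B, C:
#   [d_inner, n_qk_heads * d_state, n_qk_heads * d_state]
assert d_xBC == d_inner + 2 * (n_qk_heads * d_state)

# Padding the sequence to the nearest multiple of chunk_size, as in forward():
padded_len = (1 + (seqlen - 1) // chunk_size) * chunk_size
assert padded_len % chunk_size == 0 and 0 <= padded_len - seqlen < chunk_size

print(d_inner_proj, padded_len)                # 17440 1024
```
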