From 392e99c5229e3a9d3e2ccdb81e06ce7154c06d01 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 10 Sep 2025 15:06:28 +0300 Subject: [PATCH 1/3] Add initial implementation for using MergedColumnParallelLinear if n_groups % tp_size == 0 Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- .../layers/mamba/mamba_mixer2.py | 204 +++++++++++------- 1 file changed, 123 insertions(+), 81 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index bb3fdd38dbef..dccfe303a021 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -19,6 +19,7 @@ from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, @@ -264,9 +265,11 @@ def __init__(self, "If tensor parallel world size does not divide num_heads, " "then num_groups must equal 1.") - assert ( - self.tp_size == 1 or quant_config is None - ), "Tensor parallel currently not supported for quantized models." + assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \ + quant_config is None, ( + "Tensor parallel currently supported for quantized models only " + "if tensor parallel world size divides num groups." + ) self.ssm_state_size = ssm_state_size self.conv_kernel_size = conv_kernel_size @@ -285,92 +288,101 @@ def __init__(self, n_groups, self.tp_size) self.n_groups = n_groups + groups - self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size - self.conv1d = ColumnParallelLinear( - input_size=conv_kernel_size, - output_size=self.conv_dim, - bias=use_conv_bias, - quant_config=None, - ) - # unsqueeze to fit conv1d weights shape into the linear weights shape. - # Can't do this in `weight_loader` since it already exists in - # `ColumnParallelLinear` and `set_weight_attrs` - # doesn't allow to override it - self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + self.groups_ssm_state_size = self.n_groups * self.ssm_state_size + self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size + + if self.n_groups % self.tp_size == 0: + self.conv1d = MergedColumnParallelLinear( + input_size=conv_kernel_size, + output_sizes=[ + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + ], + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - self.in_proj = ColumnParallelLinear( - input_size=hidden_size, - output_size=intermediate_size + self.conv_dim + self.num_heads, - bias=use_bias, - quant_config=quant_config, - ) + self.in_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + intermediate_size, + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + self.num_heads, + ], + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + else: + # This is the n_groups == 1 case, + # where we need to duplicate groups if TP>1. 
+ + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - # - because in_proj is a concatenation of 3 weights, we - # need to interleave them before sharding - # - use the custom weight loader mamba_v2_sharded_weight_loader - # for conv1d.bias, covn1d.weight and in_proj.weight - # - need to set these settings, to assign the groups to the head shards - group_shard_settings = ( - self.n_groups * self.ssm_state_size, # expected model size - (self.n_groups - n_groups) * - self.ssm_state_size, # extra dims assigned - n_groups == 1, # if there was only one group - ) - intermediate_settings = (intermediate_size, 0, False) - head_settings = (self.num_heads, 0, False) - - # - the weight already has a "weight_loader" attribute - # which set_weight_attrs will raise if we do not - # delete before trying to override it - # - ditto for the otther two weights below - delattr(self.conv1d.bias, "weight_loader") - set_weight_attrs( - self.conv1d.bias, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) - delattr(self.conv1d.weight, "weight_loader") - set_weight_attrs( - self.conv1d.weight, - { - "weight_loader": - mamba_v2_sharded_weight_loader( - [ - intermediate_settings, - group_shard_settings, - group_shard_settings, - ], - self.tp_size, - tp_rank, - ) - }, - ) + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups + # to the head shards + group_shard_settings = ( + self.groups_ssm_state_size, # expected model size + (self.n_groups - n_groups) * + self.ssm_state_size, # extra dims assigned + n_groups == 1, # if there was only one group + ) + intermediate_settings = (intermediate_size, 0, False) + head_settings = (self.num_heads, 0, False) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the otther two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + tp_rank, + ) + }, + ) - if quant_config is None: - # - quant layers do not have a weight loader - delattr(self.in_proj.weight, "weight_loader") + delattr(self.conv1d.weight, "weight_loader") set_weight_attrs( - self.in_proj.weight, + self.conv1d.weight, { "weight_loader": mamba_v2_sharded_weight_loader( [ - intermediate_settings, # for gate intermediate_settings, group_shard_settings, group_shard_settings, - head_settings, # for dt ], self.tp_size, tp_rank, @@ -378,6 +390,38 @@ def __init__(self, }, ) + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": + mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + 
intermediate_settings, + group_shard_settings, + group_shard_settings, + head_settings, # for dt + ], + self.tp_size, + tp_rank, + ) + }, + ) + + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `MergedColumnParallelLinear`, + # and `set_weight_attrs` doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # fmt: off + print(f"AAAAAAAAAAAAAAAA############################# MambaMixer2 __init__: {type(self.conv1d)=}", flush=True) # noqa: E501 + print(f"AAAAAAAAAAAAAAAA############################# MambaMixer2 __init__: {type(self.in_proj)=}", flush=True) # noqa: E501 + # fmt: on + # - these are TPed by heads to reduce the size of the # temporal shape self.A = nn.Parameter( @@ -495,8 +539,6 @@ def forward_cuda( chunk_indices_p = mamba2_metadata.chunk_indices chunk_offsets_p = mamba2_metadata.chunk_offsets - groups_time_state_size = self.n_groups * self.ssm_state_size - # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -521,8 +563,8 @@ def forward_cuda( hidden_states_B_C, [ self.intermediate_size // self.tp_size, - groups_time_state_size // self.tp_size, - groups_time_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, ], dim=-1, ) From eb4d81f66cadd3e75117f3d69f89d6db10c98e15 Mon Sep 17 00:00:00 2001 From: Tomer Asida <57313761+tomeras91@users.noreply.github.com> Date: Wed, 10 Sep 2025 17:13:52 +0300 Subject: [PATCH 2/3] fix assertion comment + remove debug print Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/model_executor/layers/mamba/mamba_mixer2.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index dccfe303a021..fb6f61b26ad8 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -262,7 +262,7 @@ def __init__(self, ), "Tensor parallel world size must divide num heads." assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( - "If tensor parallel world size does not divide num_heads, " + "If tensor parallel world size does not divide num_groups, " "then num_groups must equal 1.") assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \ @@ -417,11 +417,6 @@ def __init__(self, # and `set_weight_attrs` doesn't allow to override it self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) - # fmt: off - print(f"AAAAAAAAAAAAAAAA############################# MambaMixer2 __init__: {type(self.conv1d)=}", flush=True) # noqa: E501 - print(f"AAAAAAAAAAAAAAAA############################# MambaMixer2 __init__: {type(self.in_proj)=}", flush=True) # noqa: E501 - # fmt: on - # - these are TPed by heads to reduce the size of the # temporal shape self.A = nn.Parameter( From 947c7456bc65e3bd2aa6ad8908f17c897df18d08 Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Wed, 10 Sep 2025 18:09:31 +0300 Subject: [PATCH 3/3] Fix if condition. 
Use `n_groups` instead of `self.n_groups`, since `self.n_groups` is always divisible by tp_size and would make the check trivially true

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
---
 vllm/model_executor/layers/mamba/mamba_mixer2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index fb6f61b26ad8..4575a3e6ea11 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -291,7 +291,7 @@ def __init__(self,
         self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
         self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
 
-        if self.n_groups % self.tp_size == 0:
+        if n_groups % self.tp_size == 0:
             self.conv1d = MergedColumnParallelLinear(
                 input_size=conv_kernel_size,
                 output_sizes=[
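For context, a minimal standalone sketch (not part of the patch) of the per-segment sharding that the MergedColumnParallelLinear path relies on: when every logical segment of the fused in_proj output (gate, hidden_states, B, C, dt) is divisible by the tensor-parallel world size, each rank can simply keep its contiguous slice of every segment, so the custom mamba_v2_sharded_weight_loader is only needed in the fallback branch (the n_groups == 1 case). The helper name `shard_merged_output` and the toy sizes below are hypothetical, chosen only for illustration; they are not vLLM APIs.

import torch


def shard_merged_output(weight: torch.Tensor, output_sizes: list[int],
                        tp_size: int, tp_rank: int) -> torch.Tensor:
    """Keep this rank's slice of each segment of a fused output-dim weight.

    `weight` has shape (sum(output_sizes), input_size); the result has
    shape (sum(s // tp_size for s in output_sizes), input_size), i.e. the
    local layout stays [gate | hidden_states | B | C | dt].
    """
    shards = []
    offset = 0
    for size in output_sizes:
        assert size % tp_size == 0, "tp_size must divide every segment"
        shard = size // tp_size
        segment = weight[offset:offset + size]
        shards.append(segment[tp_rank * shard:(tp_rank + 1) * shard])
        offset += size
    return torch.cat(shards, dim=0)


# Toy dimensions (hypothetical, just to keep the example small).
hidden_size, intermediate_size, ssm_state_size = 16, 8, 4
n_groups, num_heads, tp_size = 2, 4, 2
groups_ssm_state_size = n_groups * ssm_state_size
output_sizes = [intermediate_size, intermediate_size,
                groups_ssm_state_size, groups_ssm_state_size, num_heads]

full_weight = torch.randn(sum(output_sizes), hidden_size)
for rank in range(tp_size):
    local = shard_merged_output(full_weight, output_sizes, tp_size, rank)
    print(rank, tuple(local.shape))  # each rank holds 1/tp_size of every segment

Because every rank's local projection keeps the same segment order, forward_cuda can continue to split hidden_states_B_C into chunks of intermediate_size // tp_size and groups_ssm_state_size // tp_size, matching the split updated in patch 1.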