-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Mamba] Support TP>1 with quantization for mamba2 mixer in case n_groups % tp_size == 0
#24593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
392e99c
Add initial implementation for using MergedColumnParallelLinear if n_…
tomeras91 eb4d81f
fix assertion comment + remove debug print
tomeras91 947c745
Fix if condition. Use `n_groups` instead of `self.n_groups` which is …
tomeras91 d7cdb7b
Merge branch 'main' into fix-mamba2-quant-tp
tomeras91 9956fb2
Merge branch 'main' into fix-mamba2-quant-tp
tomeras91 a6f1415
Merge branch 'main' into fix-mamba2-quant-tp
tomeras91 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| from vllm.forward_context import ForwardContext, get_forward_context | ||
| from vllm.model_executor.custom_op import CustomOp | ||
| from vllm.model_executor.layers.linear import (ColumnParallelLinear, | ||
| MergedColumnParallelLinear, | ||
| RowParallelLinear) | ||
| from vllm.model_executor.layers.mamba.abstract import MambaBase | ||
| from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, | ||
|
|
@@ -261,12 +262,14 @@ def __init__(self, | |
| ), "Tensor parallel world size must divide num heads." | ||
|
|
||
| assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( | ||
| "If tensor parallel world size does not divide num_heads, " | ||
| "If tensor parallel world size does not divide num_groups, " | ||
| "then num_groups must equal 1.") | ||
|
|
||
| assert ( | ||
| self.tp_size == 1 or quant_config is None | ||
| ), "Tensor parallel currently not supported for quantized models." | ||
| assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \ | ||
| quant_config is None, ( | ||
| "Tensor parallel currently supported for quantized models only " | ||
| "if tensor parallel world size divides num groups." | ||
| ) | ||
|
|
||
| self.ssm_state_size = ssm_state_size | ||
| self.conv_kernel_size = conv_kernel_size | ||
|
|
@@ -285,99 +288,135 @@ def __init__(self, | |
| n_groups, self.tp_size) | ||
| self.n_groups = n_groups + groups | ||
|
|
||
| self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size | ||
| self.conv1d = ColumnParallelLinear( | ||
| input_size=conv_kernel_size, | ||
| output_size=self.conv_dim, | ||
| bias=use_conv_bias, | ||
| quant_config=None, | ||
| ) | ||
| # unsqueeze to fit conv1d weights shape into the linear weights shape. | ||
| # Can't do this in `weight_loader` since it already exists in | ||
| # `ColumnParallelLinear` and `set_weight_attrs` | ||
| # doesn't allow to override it | ||
| self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) | ||
| self.groups_ssm_state_size = self.n_groups * self.ssm_state_size | ||
| self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size | ||
|
|
||
| if self.n_groups % self.tp_size == 0: | ||
| self.conv1d = MergedColumnParallelLinear( | ||
| input_size=conv_kernel_size, | ||
| output_sizes=[ | ||
| intermediate_size, | ||
| self.groups_ssm_state_size, | ||
| self.groups_ssm_state_size, | ||
| ], | ||
| bias=use_conv_bias, | ||
| quant_config=None, | ||
| prefix=f"{prefix}.conv1d", | ||
| ) | ||
|
|
||
| self.in_proj = ColumnParallelLinear( | ||
| input_size=hidden_size, | ||
| output_size=intermediate_size + self.conv_dim + self.num_heads, | ||
| bias=use_bias, | ||
| quant_config=quant_config, | ||
| ) | ||
| self.in_proj = MergedColumnParallelLinear( | ||
| input_size=hidden_size, | ||
| output_sizes=[ | ||
| intermediate_size, | ||
| intermediate_size, | ||
| self.groups_ssm_state_size, | ||
| self.groups_ssm_state_size, | ||
| self.num_heads, | ||
| ], | ||
| bias=use_bias, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.in_proj", | ||
| ) | ||
| else: | ||
| # This is the n_groups == 1 case, | ||
| # where we need to duplicate groups if TP>1. | ||
|
|
||
| self.conv1d = ColumnParallelLinear( | ||
| input_size=conv_kernel_size, | ||
| output_size=self.conv_dim, | ||
| bias=use_conv_bias, | ||
| quant_config=None, | ||
| prefix=f"{prefix}.conv1d", | ||
| ) | ||
|
|
||
| # - because in_proj is a concatenation of 3 weights, we | ||
| # need to interleave them before sharding | ||
| # - use the custom weight loader mamba_v2_sharded_weight_loader | ||
| # for conv1d.bias, covn1d.weight and in_proj.weight | ||
| # - need to set these settings, to assign the groups to the head shards | ||
| group_shard_settings = ( | ||
| self.n_groups * self.ssm_state_size, # expected model size | ||
| (self.n_groups - n_groups) * | ||
| self.ssm_state_size, # extra dims assigned | ||
| n_groups == 1, # if there was only one group | ||
| ) | ||
| intermediate_settings = (intermediate_size, 0, False) | ||
| head_settings = (self.num_heads, 0, False) | ||
|
|
||
| # - the weight already has a "weight_loader" attribute | ||
| # which set_weight_attrs will raise if we do not | ||
| # delete before trying to override it | ||
| # - ditto for the otther two weights below | ||
| delattr(self.conv1d.bias, "weight_loader") | ||
| set_weight_attrs( | ||
| self.conv1d.bias, | ||
| { | ||
| "weight_loader": | ||
| mamba_v2_sharded_weight_loader( | ||
| [ | ||
| intermediate_settings, | ||
| group_shard_settings, | ||
| group_shard_settings, | ||
| ], | ||
| self.tp_size, | ||
| tp_rank, | ||
| ) | ||
| }, | ||
| ) | ||
| self.in_proj = ColumnParallelLinear( | ||
| input_size=hidden_size, | ||
| output_size=intermediate_size + self.conv_dim + self.num_heads, | ||
| bias=use_bias, | ||
| quant_config=quant_config, | ||
| prefix=f"{prefix}.in_proj", | ||
| ) | ||
|
|
||
| delattr(self.conv1d.weight, "weight_loader") | ||
| set_weight_attrs( | ||
| self.conv1d.weight, | ||
| { | ||
| "weight_loader": | ||
| mamba_v2_sharded_weight_loader( | ||
| [ | ||
| intermediate_settings, | ||
| group_shard_settings, | ||
| group_shard_settings, | ||
| ], | ||
| self.tp_size, | ||
| tp_rank, | ||
| ) | ||
| }, | ||
| ) | ||
| # - because in_proj is a concatenation of 3 weights, we | ||
| # need to interleave them before sharding | ||
| # - use the custom weight loader mamba_v2_sharded_weight_loader | ||
| # for conv1d.bias, covn1d.weight and in_proj.weight | ||
| # - need to set these settings, to assign the groups | ||
| # to the head shards | ||
| group_shard_settings = ( | ||
| self.groups_ssm_state_size, # expected model size | ||
| (self.n_groups - n_groups) * | ||
| self.ssm_state_size, # extra dims assigned | ||
| n_groups == 1, # if there was only one group | ||
| ) | ||
| intermediate_settings = (intermediate_size, 0, False) | ||
| head_settings = (self.num_heads, 0, False) | ||
|
|
||
| # - the weight already has a "weight_loader" attribute | ||
| # which set_weight_attrs will raise if we do not | ||
| # delete before trying to override it | ||
| # - ditto for the otther two weights below | ||
| delattr(self.conv1d.bias, "weight_loader") | ||
| set_weight_attrs( | ||
| self.conv1d.bias, | ||
| { | ||
| "weight_loader": | ||
| mamba_v2_sharded_weight_loader( | ||
| [ | ||
| intermediate_settings, | ||
| group_shard_settings, | ||
| group_shard_settings, | ||
| ], | ||
| self.tp_size, | ||
| tp_rank, | ||
| ) | ||
| }, | ||
| ) | ||
|
|
||
| if quant_config is None: | ||
| # - quant layers do not have a weight loader | ||
| delattr(self.in_proj.weight, "weight_loader") | ||
| delattr(self.conv1d.weight, "weight_loader") | ||
| set_weight_attrs( | ||
| self.in_proj.weight, | ||
| self.conv1d.weight, | ||
| { | ||
| "weight_loader": | ||
| mamba_v2_sharded_weight_loader( | ||
| [ | ||
| intermediate_settings, # for gate | ||
| intermediate_settings, | ||
| group_shard_settings, | ||
| group_shard_settings, | ||
| head_settings, # for dt | ||
| ], | ||
| self.tp_size, | ||
| tp_rank, | ||
| ) | ||
| }, | ||
| ) | ||
|
|
||
| if quant_config is None: | ||
| # - quant layers do not have a weight loader | ||
| delattr(self.in_proj.weight, "weight_loader") | ||
| set_weight_attrs( | ||
| self.in_proj.weight, | ||
| { | ||
| "weight_loader": | ||
| mamba_v2_sharded_weight_loader( | ||
| [ | ||
| intermediate_settings, # for gate | ||
| intermediate_settings, | ||
| group_shard_settings, | ||
| group_shard_settings, | ||
| head_settings, # for dt | ||
| ], | ||
| self.tp_size, | ||
| tp_rank, | ||
| ) | ||
| }, | ||
| ) | ||
|
|
||
| # unsqueeze to fit conv1d weights shape into the linear weights shape. | ||
| # Can't do this in `weight_loader` since it already exists in | ||
| # `ColumnParallelLinear` and `MergedColumnParallelLinear`, | ||
| # and `set_weight_attrs` doesn't allow to override it | ||
| self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) | ||
|
|
||
| # - these are TPed by heads to reduce the size of the | ||
| # temporal shape | ||
| self.A = nn.Parameter( | ||
|
|
@@ -495,8 +534,6 @@ def forward_cuda( | |
| chunk_indices_p = mamba2_metadata.chunk_indices | ||
| chunk_offsets_p = mamba2_metadata.chunk_offsets | ||
|
|
||
| groups_time_state_size = self.n_groups * self.ssm_state_size | ||
|
|
||
| # 1. Gated MLP's linear projection | ||
| projected_states, _ = self.in_proj(hidden_states) | ||
|
|
||
|
|
@@ -521,8 +558,8 @@ def forward_cuda( | |
| hidden_states_B_C, | ||
| [ | ||
| self.intermediate_size // self.tp_size, | ||
| groups_time_state_size // self.tp_size, | ||
| groups_time_state_size // self.tp_size, | ||
| self.groups_ssm_state_size // self.tp_size, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oic was this a bug?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no.. this is just some renaming since I now create |
||
| self.groups_ssm_state_size // self.tp_size, | ||
| ], | ||
| dim=-1, | ||
| ) | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.