Added support to export for BF16 weight and amax for vLLM fakequant QAT #579
Merged
Changes from all commits (10 commits by kinjalpatel27):
a93e2b4  Added support to export for BF16 weight and amax
05cd504  Updated docs
b6efc6e  minor
4daf5ce  minor
2f6c0c0  minor
bc85b5c  added seperate file for vLLM for export
b0f78c8  added test for vllm fq export
f46e41d  Added support for Qwen3-MoE
d8652d1  minor
1092699  minor
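
The diff below covers the vLLM serving side, which loads a saved amax file and remaps its keys. For orientation, here is a minimal, hypothetical sketch of what the export side could look like; it is not code from this PR. It assumes a modelopt fake-quantized HuggingFace model whose state dict carries quantizer buffers ending in `._amax`; the function name `export_bf16_and_amax`, the output file names, and the flattened `*_quantizer_amax` key form are all illustrative assumptions chosen to match what the loader filters on.

```python
import torch


def export_bf16_and_amax(model, output_dir: str) -> None:
    """Hypothetical sketch: split a fake-quantized model into BF16 weights and amax values."""
    weights: dict[str, torch.Tensor] = {}
    amax: dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        if name.endswith("._amax"):
            # Flatten "...input_quantizer._amax" -> "...input_quantizer_amax" so the
            # serving-side loader can recognize the keys by the "quantizer_amax" suffix.
            amax[name.replace("._amax", "_amax")] = tensor.clone()
        elif tensor.is_floating_point():
            # Cast floating-point weights to BF16.
            weights[name] = tensor.to(torch.bfloat16)
        else:
            # Keep integer/bool buffers unchanged.
            weights[name] = tensor
    torch.save(weights, f"{output_dir}/model_bf16.pt")
    torch.save(amax, f"{output_dir}/amax.pt")
```

On the serving side, the saved amax file is what would be passed as `amax_file_path` in the loading code shown further down.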
Relevant hunks from the diff:

@@ -15,7 +15,9 @@
 import dataclasses
 import os
+import re
 import warnings
+from collections import defaultdict
 from contextlib import contextmanager
 from typing import Any
@@ -30,6 +32,99 @@
 from modelopt.torch.utils.dataset_utils import get_dataset_dataloader
 
 
+def convert_amax_hf2vllm(
+    hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False
+) -> dict[str, torch.Tensor]:
+    """
+    Convert amax values from HuggingFace format to vLLM format.
+
+    This function merges:
+    - q_proj, k_proj, v_proj amax values into qkv_proj (taking max)
+    - gate_proj, up_proj amax values into gate_up_proj (taking max)
+
+    Args:
+        hf_state_dict: HuggingFace state dict containing amax values
+
+    Returns:
+        vLLM format state dict with merged amax values
+    """
+    vllm_state_dict = {}
+
+    # Group keys by their base pattern (without the specific projection name)
+    merge_groups = defaultdict(list)
+
+    for key, value in hf_state_dict.items():
+        if "_amax" not in key:
+            # Copy non-amax keys as-is
+            vllm_state_dict[key] = value
+            continue
+
+        # Check if this is a q/k/v projection that needs merging
+        qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key)
+        if qkv_match:
+            base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is an expert gate/up projection
+        # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and
+        #          model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax
+        # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax
+        expert_gate_up_match = (
+            "mixer" not in key
+            and fuse_experts
+            and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key)
+        )
+        if expert_gate_up_match:
+            base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is a non-expert gate/up projection that needs merging
+        gate_up_match = (
+            "mixer" not in key
+            and "experts" not in key
+            and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
+        )
+        if gate_up_match:
+            base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Check if this is an expert down_proj
+        # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax
+        # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax
+        expert_down_match = (
+            "mixer" not in key
+            and fuse_experts
+            and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key)
+        )
+        if expert_down_match:
+            base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2)
+            merge_groups[base_pattern].append((key, value))
+            continue
+
+        # Copy other amax keys as-is (like o_proj, down_proj)
+        vllm_state_dict[key] = value
+
+    # Merge grouped amax values by taking the maximum
+    for merged_key, key_value_pairs in merge_groups.items():
+        if len(key_value_pairs) > 1:
+            # Take the maximum across all values for this merged key
+            values = [value for _, value in key_value_pairs]
+            merged_value = torch.stack(values).max(dim=0)[0]
+            vllm_state_dict[merged_key] = merged_value
+            print(f"Merged {len(key_value_pairs)} keys into {merged_key}")
+            for orig_key, _ in key_value_pairs:
+                print(f" - {orig_key}")
+        else:
+            # Single key, just rename it
+            _, value = key_value_pairs[0]
+            vllm_state_dict[merged_key] = value
+
+    return vllm_state_dict

Collaborator (inline review comment): what if we miss a merging rule here? What will happen in vLLM?

 @contextmanager
 def disable_compilation(model):
     do_not_compile = True
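
To make the merging rules concrete, here is a small usage sketch with toy amax tensors. It assumes `convert_amax_hf2vllm` from the hunk above is in scope (for example, pasted into the same script), and the layer names are illustrative rather than taken from a real checkpoint.

```python
import torch

# Toy HuggingFace-format amax values (names are illustrative).
hf_amax = {
    "model.layers.0.self_attn.q_proj.input_quantizer._amax": torch.tensor(1.0),
    "model.layers.0.self_attn.k_proj.input_quantizer._amax": torch.tensor(2.0),
    "model.layers.0.self_attn.v_proj.input_quantizer._amax": torch.tensor(0.5),
    "model.layers.0.mlp.gate_proj.input_quantizer._amax": torch.tensor(3.0),
    "model.layers.0.mlp.up_proj.input_quantizer._amax": torch.tensor(4.0),
    "model.layers.0.mlp.down_proj.input_quantizer._amax": torch.tensor(1.5),
    # MoE expert projections, only fused when fuse_experts=True:
    "model.layers.1.mlp.experts.0.gate_proj.input_quantizer._amax": torch.tensor(2.0),
    "model.layers.1.mlp.experts.1.up_proj.input_quantizer._amax": torch.tensor(5.0),
    "model.layers.1.mlp.experts.0.down_proj.input_quantizer._amax": torch.tensor(0.7),
}

vllm_amax = convert_amax_hf2vllm(hf_amax, fuse_experts=True)

# q/k/v amax are merged into a single qkv_proj key by taking the max.
assert vllm_amax["model.layers.0.self_attn.qkv_proj.input_quantizer._amax"] == 2.0
# gate/up amax are merged into gate_up_proj.
assert vllm_amax["model.layers.0.mlp.gate_up_proj.input_quantizer._amax"] == 4.0
# Per-expert gate/up amax collapse into one fused w13 quantizer key.
assert vllm_amax["model.layers.1.mlp.experts.w13_input_quantizer._amax"] == 5.0
# Per-expert down_proj amax collapse into one fused w2 quantizer key.
assert vllm_amax["model.layers.1.mlp.experts.w2_input_quantizer._amax"] == 0.7
# Dense down_proj (and o_proj) keys are copied through unchanged.
assert vllm_amax["model.layers.0.mlp.down_proj.input_quantizer._amax"] == 1.5
```

Any amax key the regexes do not recognize falls through to the copy-as-is branch and keeps its HuggingFace-style name, which is the situation the review comment above asks about.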
@@ -154,8 +249,17 @@ def calibrate_loop(model: Any = None) -> None:
         if amax_file_path:
             print(f"Loading amax values from {amax_file_path}")
             saved_amax_dict = torch.load(amax_file_path)
-            current_state_dict = model.state_dict()
+            # convert amax keys to vLLM format
+            if hasattr(self.model_runner.model, "hf_to_vllm_mapper"):
+                saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict)
+            saved_amax_dict = {
+                key.replace("quantizer_amax", "quantizer._amax"): value
+                for key, value in saved_amax_dict.items()
+                if key.endswith("quantizer_amax")
+            }
+            saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True)
+            current_state_dict = model.state_dict()
             # Count amax keys in checkpoint and model
             checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")]
             model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")]
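
The hunk ends at the key-counting step. As a rough continuation sketch (reusing the variables defined above, and assuming the remaining code simply reports coverage and copies matching amax values into the live model, which is not necessarily what this PR does), the follow-up could look like:

```python
# Hypothetical continuation, not taken from the PR diff.
missing_in_ckpt = set(model_amax_keys) - set(checkpoint_amax_keys)
unused_in_ckpt = set(checkpoint_amax_keys) - set(model_amax_keys)
if missing_in_ckpt or unused_in_ckpt:
    # `warnings` is imported at the top of the file (see the first hunk).
    warnings.warn(
        f"amax key mismatch: {len(missing_in_ckpt)} model keys missing from the checkpoint, "
        f"{len(unused_in_ckpt)} checkpoint keys unused"
    )

# Overwrite the model's amax buffers with the calibrated values from the checkpoint.
for key in model_amax_keys:
    if key in saved_amax_dict:
        current_state_dict[key] = saved_amax_dict[key].to(current_state_dict[key].device)
model.load_state_dict(current_state_dict, strict=False)
```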
Review comment: Thanks Kinjal for documenting this. Created a Jira ticket to address this: https://jirasw.nvidia.com/browse/OMNIML-3051
Reply: Thank you for creating the ticket.