diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index ffe0acc53..899b14009 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
 - Add per tensor and per channel MSE calibrator support.
+- Add support for PTQ/QAT checkpoint export and loading to run fakequant evaluation in vLLM. See ``examples/vllm_serve/README.md#load-qatptq-model-and-serve-in-vllm-wip`` for more details.
 
 **Documentation**
 
diff --git a/examples/vllm_serve/README.md b/examples/vllm_serve/README.md
index 64a4147c2..90c053d5b 100644
--- a/examples/vllm_serve/README.md
+++ b/examples/vllm_serve/README.md
@@ -55,16 +55,19 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=,
 
 ## Load QAT/PTQ model and serve in vLLM (WIP)
 
-Overwrite the calibrated amax value with prepared values from either PTQ/QAT. This is only tested for Llama3.1
+Overwrite the calibrated amax values with values prepared during QAT or PTQ.
 
-Step 1: convert amax to merged amax, using llama3.1 as an example:
+Step 1: export the model with bf16 weights and amax values.
+
+- For HF models, set `export_bf16_weights_amax` and export with `modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
+- For MCore models, set `export_bf16_weights_amax` and export with `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
+
+Step 2: serve the exported model, pointing the `AMAX_FILE_PATH` environment variable at the amax file produced in step 1. For example:
 
 ```bash
-python convert_amax_hf2vllm.py -i -o
+AMAX_FILE_PATH= QUANT_CFG= python vllm_serve_fakequant.py -tp 8 --host 0.0.0.0 --port 8000
 ```
 
-Step 2: add `` to `quant_config` in `vllm_serve_fakequant.py`
-
 ## Important Notes
 
 **Amax Synchronization across Tensor Parallel (TP):**
@@ -85,3 +88,5 @@ torch.distributed.barrier()
 ## Known Problems
 
 1. AWQ is not yet supported in vLLM.
+2. PTQ/QAT checkpoints do not work with KV cache quantization enabled.
+3. Mixed-precision checkpoints are not currently supported.
diff --git a/examples/vllm_serve/convert_amax_hf2vllm.py b/examples/vllm_serve/convert_amax_hf2vllm.py
deleted file mode 100644
index 6f0321a91..000000000
--- a/examples/vllm_serve/convert_amax_hf2vllm.py
+++ /dev/null
@@ -1,213 +0,0 @@
-#!/usr/bin/env python3
-
-# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import argparse
-import os
-import re
-from collections import defaultdict
-
-import torch
-
-
-def convert_amax_hf2vllm(
-    hf_state_dict: dict[str, torch.Tensor],
-) -> dict[str, torch.Tensor]:
-    """
-    Convert amax values from HuggingFace format to vLLM format.
- - This function merges: - - q_proj, k_proj, v_proj amax values into qkv_proj (taking max) - - gate_proj, up_proj amax values into gate_up_proj (taking max) - - Args: - hf_state_dict: HuggingFace state dict containing amax values - - Returns: - vLLM format state dict with merged amax values - """ - vllm_state_dict = {} - - # Group keys by their base pattern (without the specific projection name) - merge_groups = defaultdict(list) - - for key, value in hf_state_dict.items(): - if "_amax" not in key: - # Copy non-amax keys as-is - vllm_state_dict[key] = value - continue - - # Check if this is a q/k/v projection that needs merging - qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key) - if qkv_match: - base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Check if this is a gate/up projection that needs merging - gate_up_match = re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key) - if gate_up_match: - base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3) - merge_groups[base_pattern].append((key, value)) - continue - - # Copy other amax keys as-is (like o_proj, down_proj) - vllm_state_dict[key] = value - - # Merge grouped amax values by taking the maximum - for merged_key, key_value_pairs in merge_groups.items(): - if len(key_value_pairs) > 1: - # Take the maximum across all values for this merged key - values = [value for _, value in key_value_pairs] - merged_value = torch.stack(values).max(dim=0)[0] - vllm_state_dict[merged_key] = merged_value - print(f"Merged {len(key_value_pairs)} keys into {merged_key}") - for orig_key, _ in key_value_pairs: - print(f" - {orig_key}") - else: - # Single key, just rename it - _, value = key_value_pairs[0] - vllm_state_dict[merged_key] = value - - return vllm_state_dict - - -def test_conversion(): - """Test the conversion logic with sample keys""" - import torch - - # Create sample HF state dict - sample_hf_keys = [ - "model.layers.0.self_attn.q_proj.input_quantizer._amax", - "model.layers.0.self_attn.k_proj.input_quantizer._amax", - "model.layers.0.self_attn.v_proj.input_quantizer._amax", - "model.layers.0.self_attn.q_proj.weight_quantizer._amax", - "model.layers.0.self_attn.k_proj.weight_quantizer._amax", - "model.layers.0.self_attn.v_proj.weight_quantizer._amax", - "model.layers.0.self_attn.o_proj.input_quantizer._amax", - "model.layers.0.self_attn.o_proj.weight_quantizer._amax", - "model.layers.0.mlp.gate_proj.input_quantizer._amax", - "model.layers.0.mlp.up_proj.input_quantizer._amax", - "model.layers.0.mlp.gate_proj.weight_quantizer._amax", - "model.layers.0.mlp.up_proj.weight_quantizer._amax", - "model.layers.0.mlp.down_proj.input_quantizer._amax", - "model.layers.0.mlp.down_proj.weight_quantizer._amax", - ] - - hf_state_dict = {} - for key in sample_hf_keys: - hf_state_dict[key] = torch.tensor([1.0, 2.0, 3.0]) # Sample values - - print("Testing conversion with sample keys...") - print(f"Input keys: {len(sample_hf_keys)}") - - vllm_state_dict = convert_amax_hf2vllm(hf_state_dict) - vllm_amax_keys = [k for k in vllm_state_dict if "_amax" in k] - - print(f"Output keys: {len(vllm_amax_keys)}") - print("\nExpected vLLM keys:") - expected_keys = [ - "model.layers.0.self_attn.qkv_proj.input_quantizer._amax", - "model.layers.0.self_attn.qkv_proj.weight_quantizer._amax", - "model.layers.0.self_attn.o_proj.input_quantizer._amax", - "model.layers.0.self_attn.o_proj.weight_quantizer._amax", - 
"model.layers.0.mlp.gate_up_proj.input_quantizer._amax", - "model.layers.0.mlp.gate_up_proj.weight_quantizer._amax", - "model.layers.0.mlp.down_proj.input_quantizer._amax", - "model.layers.0.mlp.down_proj.weight_quantizer._amax", - ] - - for key in expected_keys: - print(f" {key}") - - print("\nActual vLLM keys:") - for key in sorted(vllm_amax_keys): - print(f" {key}") - - # Check if all expected keys are present - missing_keys = set(expected_keys) - set(vllm_amax_keys) - extra_keys = set(vllm_amax_keys) - set(expected_keys) - - if missing_keys: - print(f"\nMissing keys: {missing_keys}") - if extra_keys: - print(f"\nExtra keys: {extra_keys}") - - if not missing_keys and not extra_keys: - print("\n✓ Test passed! All keys converted correctly.") - else: - print("\n✗ Test failed! Key mismatch detected.") - - -def main(): - parser = argparse.ArgumentParser( - description="Convert amax values from HuggingFace to vLLM format" - ) - parser.add_argument("--input", "-i", help="Input HuggingFace checkpoint path") - parser.add_argument("--output", "-o", help="Output vLLM checkpoint path") - parser.add_argument("--dry-run", action="store_true", help="Show conversion without saving") - parser.add_argument("--test", action="store_true", help="Run test with sample data") - - args = parser.parse_args() - - if args.test: - test_conversion() - return - - if not args.input or not args.output: - parser.error("--input and --output are required unless using --test") - - # Load HuggingFace checkpoint - print(f"Loading HuggingFace checkpoint from: {args.input}") - if os.path.isfile(args.input): - hf_state_dict = torch.load(args.input, map_location="cpu") - else: - raise Exception(f"File not found: {args.input}") - - print(f"Loaded {len(hf_state_dict)} keys from HuggingFace checkpoint") - - # Filter to only amax keys for analysis - amax_keys = [k for k in hf_state_dict if "_amax" in k] - print(f"Found {len(amax_keys)} amax keys") - - if args.dry_run: - print("\nAmax keys in HuggingFace format:") - for key in sorted(amax_keys): - print(f" {key}") - - # Convert to vLLM format - print("\nConverting to vLLM format...") - vllm_state_dict = convert_amax_hf2vllm(hf_state_dict) - - vllm_amax_keys = [k for k in vllm_state_dict if "_amax" in k] - print(f"Result: {len(vllm_amax_keys)} amax keys in vLLM format") - - if args.dry_run: - print("\nAmax keys in vLLM format:") - for key in sorted(vllm_amax_keys): - print(f" {key}") - print("\nDry run complete. No files saved.") - return - - # Save vLLM checkpoint - print(f"Saving vLLM checkpoint to: {args.output}") - os.makedirs(os.path.dirname(args.output), exist_ok=True) - torch.save(vllm_state_dict, args.output) - print("Conversion complete!") - - -if __name__ == "__main__": - main() diff --git a/examples/vllm_serve/fakequant_worker.py b/examples/vllm_serve/fakequant_worker.py index 8532c369f..d08e62340 100644 --- a/examples/vllm_serve/fakequant_worker.py +++ b/examples/vllm_serve/fakequant_worker.py @@ -15,7 +15,9 @@ import dataclasses import os +import re import warnings +from collections import defaultdict from contextlib import contextmanager from typing import Any @@ -30,6 +32,99 @@ from modelopt.torch.utils.dataset_utils import get_dataset_dataloader +def convert_amax_hf2vllm( + hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False +) -> dict[str, torch.Tensor]: + """ + Convert amax values from HuggingFace format to vLLM format. 
+ + This function merges: + - q_proj, k_proj, v_proj amax values into qkv_proj (taking max) + - gate_proj, up_proj amax values into gate_up_proj (taking max) + + Args: + hf_state_dict: HuggingFace state dict containing amax values + + Returns: + vLLM format state dict with merged amax values + """ + vllm_state_dict = {} + + # Group keys by their base pattern (without the specific projection name) + merge_groups = defaultdict(list) + + for key, value in hf_state_dict.items(): + if "_amax" not in key: + # Copy non-amax keys as-is + vllm_state_dict[key] = value + continue + + # Check if this is a q/k/v projection that needs merging + qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key) + if qkv_match: + base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3) + merge_groups[base_pattern].append((key, value)) + continue + + # Check if this is an expert gate/up projection + # Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and + # model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax + # Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax + expert_gate_up_match = ( + "mixer" not in key + and fuse_experts + and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key) + ) + if expert_gate_up_match: + base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3) + merge_groups[base_pattern].append((key, value)) + continue + + # Check if this is a non-expert gate/up projection that needs merging + gate_up_match = ( + "mixer" not in key + and "experts" not in key + and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key) + ) + if gate_up_match: + base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3) + merge_groups[base_pattern].append((key, value)) + continue + + # Check if this is an expert down_proj + # Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax + # Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax + expert_down_match = ( + "mixer" not in key + and fuse_experts + and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key) + ) + if expert_down_match: + base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2) + merge_groups[base_pattern].append((key, value)) + continue + + # Copy other amax keys as-is (like o_proj, down_proj) + vllm_state_dict[key] = value + + # Merge grouped amax values by taking the maximum + for merged_key, key_value_pairs in merge_groups.items(): + if len(key_value_pairs) > 1: + # Take the maximum across all values for this merged key + values = [value for _, value in key_value_pairs] + merged_value = torch.stack(values).max(dim=0)[0] + vllm_state_dict[merged_key] = merged_value + print(f"Merged {len(key_value_pairs)} keys into {merged_key}") + for orig_key, _ in key_value_pairs: + print(f" - {orig_key}") + else: + # Single key, just rename it + _, value = key_value_pairs[0] + vllm_state_dict[merged_key] = value + + return vllm_state_dict + + @contextmanager def disable_compilation(model): do_not_compile = True @@ -154,8 +249,17 @@ def calibrate_loop(model: Any = None) -> None: if amax_file_path: print(f"Loading amax values from {amax_file_path}") saved_amax_dict = torch.load(amax_file_path) - current_state_dict = model.state_dict() + # convert amax keys to vLLM format + if hasattr(self.model_runner.model, "hf_to_vllm_mapper"): + saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict) + saved_amax_dict = { + key.replace("quantizer_amax", 
"quantizer._amax"): value + for key, value in saved_amax_dict.items() + if key.endswith("quantizer_amax") + } + saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True) + current_state_dict = model.state_dict() # Count amax keys in checkpoint and model checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")] model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")] diff --git a/modelopt/torch/export/plugins/vllm_fakequant.py b/modelopt/torch/export/plugins/vllm_fakequant.py new file mode 100644 index 000000000..370d886b7 --- /dev/null +++ b/modelopt/torch/export/plugins/vllm_fakequant.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Export functions for vLLM fakequant.""" + +import os +from pathlib import Path + +import torch +import torch.nn as nn + +from modelopt.torch.export.layer_utils import is_quantlinear +from modelopt.torch.export.model_config import QUANTIZATION_NONE +from modelopt.torch.quantization.utils import get_quantizer_state_dict + + +def export_hf_vllm_fq_checkpoint( + model: nn.Module, + export_dir: Path | str, +) -> dict[str, torch.Tensor]: + """Exports the torch model weights and amax values separately. + + This function: + 1. Extracts amax values for calibration + 2. Deletes all quantizer parameters from state dict to store only weights in original dtype + + Args: + model: The quantized model to export + export_dir: Directory to save the amax values + + Returns: + post_state_dict: Dict containing quantized weights + """ + amax_dict = { + name + "._amax": param["_amax"].detach().clone().cpu() + for name, param in get_quantizer_state_dict(model).items() + if "_amax" in param + } + + # remove quantizer from model + for _, module in model.named_modules(): + if is_quantlinear(module): + delattr(module, "weight_quantizer") + delattr(module, "input_quantizer") + delattr(module, "output_quantizer") + module.export() + torch.save(amax_dict, f"{export_dir}/quant_amax.pth") + return model.state_dict() + + +def get_mcore_vllm_fq_quantized_state( + module: torch.nn.Module, name_to_value: dict, dtype: torch.dtype = torch.bfloat16 +): + """Return a state_dict, quantization format, and block_size of the quantized module. + + Args: + module: The target module to perform real quantization. + name_to_value: The dictionary to store the quantized state. + dtype: The default data type. + + Returns: + Tuple: state dict, quantization format, and block_size of the quantized module. 
+ + """ + qformat: str = QUANTIZATION_NONE + block_size = 0 + + for name, param in get_quantizer_state_dict(module).items(): + if "_amax" in param: + name_to_value[name + "._amax"] = param["_amax"].to(dtype).cpu() + return name_to_value, qformat, block_size + + +def gather_mcore_vllm_fq_quantized_state_dict( + state_dict: dict[str, torch.Tensor], save_directory: str | os.PathLike +): + """Gather all quantized state dict from all ranks and save them to a file. + + Args: + state_dict: The state dictionary of the module. + save_directory: The directory to save the quantized state dict. + + Returns: + The state dictionary of the module without quantized state. + """ + amax_state_dict = { + k: v.detach().clone().cpu() for k, v in state_dict.items() if k.endswith("_amax") + } + + # Gather all amax dicts to rank 0 + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + if rank == 0: + # Rank 0 will collect all amax values + all_amax_dicts = [None] * world_size + torch.distributed.gather_object(amax_state_dict, all_amax_dicts, dst=0) + + # Merge all amax dicts into one + merged_amax_dict = {} + for amax_dict in all_amax_dicts: + if amax_dict is not None: + merged_amax_dict.update(amax_dict) + + print(f"Total amax entries from all ranks: {len(merged_amax_dict.keys())}") + torch.save(merged_amax_dict, save_directory + "/quant_amax.pth") + else: + # Other ranks just send their amax values + torch.distributed.gather_object(amax_state_dict, None, dst=0) + + torch.distributed.barrier() + + # remove amax values from state_dict + return {k: v for k, v in state_dict.items() if not k.endswith("_amax")} diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 447338690..6deb479f1 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -59,6 +59,7 @@ ) from .model_utils import get_language_model_from_vl, is_multimodal_model from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only +from .plugins.vllm_fakequant import export_hf_vllm_fq_checkpoint from .quant_utils import ( fuse_prequant_layernorm, fuse_prequant_to_linear, @@ -558,6 +559,7 @@ def export_hf_checkpoint( dtype: torch.dtype | None = None, export_dir: Path | str = tempfile.gettempdir(), save_modelopt_state: bool = False, + export_vllm_fq_weights_qstate: bool = False, ): """Exports the torch model to unified checkpoint and saves to export_dir. @@ -566,6 +568,8 @@ def export_hf_checkpoint( dtype: the weights data type to export the unquantized layers or the default model data type if None. export_dir: the target export path. save_modelopt_state: whether to save the modelopt state_dict. + export_vllm_fq_weights_qstate: whether to export the weights and quantization state separately for vLLM + fakequant serving. 
""" export_dir = Path(export_dir) export_dir.mkdir(parents=True, exist_ok=True) @@ -579,13 +583,18 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype) + if export_vllm_fq_weights_qstate: + post_state_dict = export_hf_vllm_fq_checkpoint(model, export_dir) + hf_quant_config = None + else: + post_state_dict, hf_quant_config = _export_hf_checkpoint(model, dtype) - # Save hf_quant_config.json for backward compatibility - with open(f"{export_dir}/hf_quant_config.json", "w") as file: - json.dump(hf_quant_config, file, indent=4) + if hf_quant_config is not None: + # Save hf_quant_config.json for\ backward compatibility + with open(f"{export_dir}/hf_quant_config.json", "w") as file: + json.dump(hf_quant_config, file, indent=4) - hf_quant_config = convert_hf_quant_config_format(hf_quant_config) + hf_quant_config = convert_hf_quant_config_format(hf_quant_config) # Save model model.save_pretrained( @@ -598,7 +607,8 @@ def export_hf_checkpoint( with open(original_config) as file: config_data = json.load(file) - config_data["quantization_config"] = hf_quant_config + if hf_quant_config is not None: + config_data["quantization_config"] = hf_quant_config with open(original_config, "w") as file: json.dump(config_data, file, indent=4) diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py index e31530109..ba0d76d4f 100644 --- a/modelopt/torch/export/unified_export_megatron.py +++ b/modelopt/torch/export/unified_export_megatron.py @@ -20,6 +20,7 @@ import json import os +import shutil import tempfile from collections import OrderedDict from pathlib import Path @@ -29,7 +30,7 @@ import torch import torch.distributed import torch.nn as nn -from huggingface_hub import snapshot_download +from huggingface_hub import hf_hub_download, snapshot_download from safetensors.torch import safe_open, save_file from tqdm import tqdm @@ -41,11 +42,16 @@ QUANTIZATION_FP8, QUANTIZATION_FP8_PB_REAL, QUANTIZATION_FP8_PB_WO, + QUANTIZATION_NONE, QUANTIZATION_NVFP4, ) from .plugins.mcore_common import all_mcore_hf_export_mapping from .plugins.mcore_custom import CustomModuleMapping, save_safetensors from .plugins.megatron_importer import GPTModelImporter +from .plugins.vllm_fakequant import ( + gather_mcore_vllm_fq_quantized_state_dict, + get_mcore_vllm_fq_quantized_state, +) from .quant_utils import ( get_activation_scaling_factor, get_kv_cache_dtype, @@ -77,7 +83,10 @@ has_mcore = True -__all__ = ["export_mcore_gpt_to_hf", "import_mcore_gpt_from_hf"] +__all__ = [ + "export_mcore_gpt_to_hf", + "import_mcore_gpt_from_hf", +] # This path uses output_quantizer for KV cache quantization. @@ -109,13 +118,15 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor: def get_quantized_state( module: torch.nn.Module, - dtype: torch.dtype = torch.bfloat16, + dtype: torch.dtype = torch.float16, + export_vllm_fq_weights_qstate: bool = False, ) -> tuple[dict[str, torch.Tensor], str, int]: """Return a state_dict, quantization format, and block_size of the module. Args: module: The target module to perform real quantization. dtype: The default data type. + export_vllm_fq_weights_qstate: Whether to export the weights in bf16 and amax values. Returns: Tuple: state_dict, quantization format, and block_size of the module. 
@@ -136,6 +147,9 @@ def get_quantized_state( if hasattr(module, "expert_bias") and module.expert_bias is not None: name_to_value["expert_bias"] = module.expert_bias.to(dtype).cpu() + if export_vllm_fq_weights_qstate: + return get_mcore_vllm_fq_quantized_state(module, name_to_value, dtype) + # Getting the weight scales weight_scale = get_weight_scaling_factor(module) weight_scale_2 = get_weight_scaling_factor_2(module) @@ -187,6 +201,7 @@ def __init__( dtype=torch.bfloat16, trust_remote_code: bool = True, moe_router_dtype: torch.dtype | None = None, + export_vllm_fq_weights_qstate: bool = False, ): """Create a GPTModel exporter instance.""" if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)): @@ -222,6 +237,7 @@ def __init__( self.model = model.language_model if self.is_multimodal else model self.dtype = dtype self.trust_remote_code = trust_remote_code + self.export_vllm_fq_weights_qstate = export_vllm_fq_weights_qstate self.arch = self._hf_config.architectures[0] # TODO: May modify this later according to what quantization exported ckpt is, currently only support BF16. if self.arch == "GptOssForCausalLM": @@ -331,7 +347,11 @@ def save_pretrained( # Main export process state_dict = self.extra_state_dict if self.export_extra_modules else self.state_dict - quantization_format = get_quantization_format(self.model) + quantization_format = ( + get_quantization_format(self.model) + if not self.export_vllm_fq_weights_qstate + else QUANTIZATION_NONE + ) quantization = None kv_cache_quantization = None @@ -378,7 +398,7 @@ def save_pretrained( except (OSError, ValueError, ImportError): pass - if is_last_stage_main_rank: + if is_last_stage_main_rank and not self.export_vllm_fq_weights_qstate: hf_quant_config = { "producer": { "name": "modelopt", @@ -398,6 +418,9 @@ def save_pretrained( and self.is_multimodal and pretrained_model_name_or_path is not None ): + assert not self.export_vllm_fq_weights_qstate, ( + "Exporting weights in bf16 and amax values is not supported for multimodal models" + ) hf_checkpoint_path = Path(pretrained_model_name_or_path) if not hf_checkpoint_path.is_dir(): hf_checkpoint_path = tempfile.gettempdir() + "/" + pretrained_model_name_or_path @@ -466,6 +489,9 @@ def save_pretrained( torch.distributed.barrier() if self.export_extra_modules: + assert not self.export_vllm_fq_weights_qstate, ( + "Exporting weights in bf16 and amax values is not supported for extra modules" + ) if is_last_stage_main_rank: save_file( state_dict, save_directory + "/model.safetensors", metadata={"format": "pt"} @@ -473,6 +499,45 @@ def save_pretrained( torch.distributed.barrier() return + if self.export_vllm_fq_weights_qstate: + state_dict = gather_mcore_vllm_fq_quantized_state_dict(state_dict, save_directory) + + if ( + is_last_stage_main_rank + and self._hf_config is not None + and pretrained_model_name_or_path is not None + ): + # For models that keep configuration and modeling files as part of the checkpoint, + # we need to copy them to the export directory for seamless integration with inference + # frameworks. 
+ hf_checkpoint_path = Path(pretrained_model_name_or_path) + model_type = getattr(self._hf_config, "model_type", None) + + if hf_checkpoint_path.is_dir(): + # Local directory - files should be there + config_file = hf_checkpoint_path / f"configuration_{model_type}.py" + modeling_file = hf_checkpoint_path / f"modeling_{model_type}.py" + else: + # Remote model ID - download from HuggingFace Hub (cached automatically) + try: + config_file = hf_hub_download( + repo_id=pretrained_model_name_or_path, + filename=f"configuration_{model_type}.py", + ) + except Exception: + config_file = "" + try: + modeling_file = hf_hub_download( + repo_id=pretrained_model_name_or_path, filename=f"modeling_{model_type}.py" + ) + except Exception: + modeling_file = "" + + if config_file and os.path.exists(config_file): + shutil.copy(config_file, f"{save_directory}/configuration_{model_type}.py") + if modeling_file and os.path.exists(modeling_file): + shutil.copy(modeling_file, f"{save_directory}/modeling_{model_type}.py") + save_safetensors(state_dict, save_directory) @property @@ -544,7 +609,9 @@ def _name_remapping( self._state_dict[prefix] = module return - name_to_value, qformat, block_size = get_quantized_state(module, dtype) + name_to_value, qformat, block_size = get_quantized_state( + module, dtype, self.export_vllm_fq_weights_qstate + ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -576,7 +643,9 @@ def _name_remapping( def _gated_mlp_slicing( self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj" ): - name_to_value, qformat, block_size = get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = get_quantized_state( + module, self.dtype, self.export_vllm_fq_weights_qstate + ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -641,7 +710,9 @@ def _qkv_slicing( k_scale_name="k_scale", v_scale_name="v_scale", ): - name_to_value, qformat, block_size = get_quantized_state(module, self.dtype) + name_to_value, qformat, block_size = get_quantized_state( + module, self.dtype, self.export_vllm_fq_weights_qstate + ) q_proj_prefix = prefix + q_proj_name + "." k_proj_prefix = prefix + k_proj_name + "." @@ -764,7 +835,7 @@ def _pack_name_remapping(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, self.export_vllm_fq_weights_qstate ) weight = name_to_value.pop("weight") weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat) @@ -830,7 +901,7 @@ def _pack_name_remapping_gpt_oss(self, module, prefix, layer_type=None): for expert in module: assert layer_type is not None, "layer_type is required for pack_name_remapping" name_to_value, qformat, block_size = get_quantized_state( - getattr(expert, layer_type), self.dtype + getattr(expert, layer_type), self.dtype, self.export_vllm_fq_weights_qstate ) weight = name_to_value.pop("weight") bias = name_to_value.pop("bias", None) @@ -1170,6 +1241,7 @@ def export_mcore_gpt_to_hf( dtype: torch.dtype = torch.bfloat16, export_dir: Path | str = tempfile.gettempdir(), moe_router_dtype: torch.dtype | None = None, + export_vllm_fq_weights_qstate: bool = False, ): """Export Megatron Core GPTModel to unified checkpoint and save to export_dir. 
@@ -1183,6 +1255,7 @@ def export_mcore_gpt_to_hf( eagle_module. Otherwise, only export the base model. dtype: The weights data type to export the unquantized layers. export_dir: The target export path. + export_vllm_fq_weights_qstate: If True, export the weights in bf16 and amax values. """ exporter = GPTModelExporter( model, @@ -1190,6 +1263,7 @@ def export_mcore_gpt_to_hf( export_extra_modules=export_extra_modules, dtype=dtype, moe_router_dtype=moe_router_dtype, + export_vllm_fq_weights_qstate=export_vllm_fq_weights_qstate, ) exporter.save_pretrained(export_dir, pretrained_model_name_or_path) diff --git a/modelopt/torch/quantization/plugins/vllm.py b/modelopt/torch/quantization/plugins/vllm.py index c35f7760b..9676d2c89 100644 --- a/modelopt/torch/quantization/plugins/vllm.py +++ b/modelopt/torch/quantization/plugins/vllm.py @@ -21,14 +21,21 @@ import vllm.model_executor.layers.fused_moe.layer as vllm_fused_moe_layer import vllm.model_executor.layers.linear as vllm_linear -try: - import vllm.model_executor.layers.fused_moe.shared_fused_moe as vllm_shared_fused_moe_layer -except ImportError: - vllm_shared_fused_moe_layer = None - from ...utils.distributed import ParallelState from ..nn import QuantLinearConvBase, QuantModule, QuantModuleRegistry, TensorQuantizer +# Try multiple import paths for vLLM compatibility across versions +vllm_shared_fused_moe_layer = None +for module_path in [ + "vllm.model_executor.layers.fused_moe.shared_fused_moe", # 0.11.0+ + "vllm.model_executor.layers.shared_fused_moe.shared_fused_moe", # 0.10.2 +]: + try: + vllm_shared_fused_moe_layer = importlib.import_module(module_path) + break + except ImportError: + continue + vllm_fused_moe_package = importlib.import_module("vllm.model_executor.layers.fused_moe.fused_moe") diff --git a/tests/gpu/torch/export/test_vllm_fakequant_export.py b/tests/gpu/torch/export/test_vllm_fakequant_export.py new file mode 100644 index 000000000..127e0f57e --- /dev/null +++ b/tests/gpu/torch/export/test_vllm_fakequant_export.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from copy import deepcopy +from functools import partial + +import pytest +import torch +from _test_utils.import_helper import skip_if_no_megatron +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.megatron.models import get_mcore_gpt_model +from _test_utils.torch.transformers_models import create_tiny_llama_dir +from transformers import AutoModelForCausalLM + +import modelopt.torch.quantization as mtq +from modelopt.torch.export.unified_export_hf import export_hf_checkpoint +from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf + +skip_if_no_megatron(apex_or_te_required=True) + + +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG]) +def test_hf_vllm_export(tmp_path, quant_cfg): + """Test HuggingFace model export for vLLM with fake quantization. + + This test verifies: + 1. Model weights match before and after export + 2. quant_amax.pth file is created, huggingface config file does not exist + 3. Amax values are correctly extracted and saved in quant_amax.pth file + """ + + # Create a tiny LLaMA model for testing + tiny_model_dir = create_tiny_llama_dir(tmp_path, with_tokenizer=True, num_hidden_layers=2) + + # Load the model + model = AutoModelForCausalLM.from_pretrained(tiny_model_dir) + model = model.cuda() + model.eval() + + # Quantize the model + def forward_loop(model): + input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda() + with torch.no_grad(): + model(input_ids) + + model = mtq.quantize(model, quant_cfg, forward_loop) + + model_state_dict = deepcopy(model.state_dict()) + + # Export directory + export_dir = tmp_path / "vllm_export" + export_dir.mkdir(exist_ok=True) + + # Export for vLLM + export_hf_checkpoint(model, export_dir=export_dir, export_vllm_fq_weights_qstate=True) + + # check if quant_amax.pth file exists + quant_amax_file = export_dir / "quant_amax.pth" + assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + + # make sure hf_quant_config.json file does not exist + hf_quant_config_file = export_dir / "hf_quant_config.json" + assert not hf_quant_config_file.exists(), ( + f"hf_quant_config.json file should not be created in {export_dir}" + ) + + # check weights match before and after export + model_after = AutoModelForCausalLM.from_pretrained(export_dir) + model_after = model_after.cuda() + model_after.eval() + model_after_state_dict = model_after.state_dict() + amax_state_dict = {} + for key, param in model_state_dict.items(): + if key.endswith("_amax"): + amax_state_dict[key] = param + continue + + assert torch.allclose(param, model_after_state_dict[key], atol=1e-6), ( + f"Weight mismatch for {key}: " + f"before shape={param.shape}, after shape={model_after_state_dict[key].shape}, " + f"max diff={torch.abs(param - model_after_state_dict[key]).max()}" + ) + + # Verify amax values are correct + amax_dict = torch.load(quant_amax_file) + assert len(amax_dict) > 0, "amax_dict should not be empty" + assert amax_dict.keys() == amax_state_dict.keys(), ( + "amax keys mismatch between before and after export" + ) + + +def _test_mcore_vllm_export(tmp_path, quant_cfg, rank, size): + """Test megatron-core model export for vLLM with fake quantization.""" + # Create a tiny mcore GPT model + num_layers = 2 + hidden_size = 64 + num_attention_heads = 8 + num_query_groups = size + ffn_hidden_size = 128 + max_sequence_length = 32 + vocab_size = 64 + + model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + 
pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="swiglu", + normalization="RMSNorm", + transformer_impl="modelopt", + ).cuda() + model.eval() + + # Quantize the model + def forward_loop(model): + batch_size = 1 + seq_len = 32 + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).cuda() + position_ids = torch.arange(seq_len).unsqueeze(0).cuda() + # Create causal attention mask + attention_mask = torch.tril(torch.ones((1, 1, seq_len, seq_len))).cuda() + attention_mask = attention_mask < 0.5 # Convert to boolean mask + with torch.no_grad(): + model(input_ids, position_ids, attention_mask) + + model = mtq.quantize(model, quant_cfg, forward_loop) + # Create HF config for export + pretrained_config = { + "architectures": ["LlamaForCausalLM"], + "attention_bias": False, + "hidden_size": hidden_size, + "intermediate_size": ffn_hidden_size, + "max_position_embeddings": max_sequence_length, + "model_type": "llama", + "num_attention_heads": num_attention_heads, + "num_hidden_layers": num_layers, + "num_key_value_heads": num_query_groups, + "torch_dtype": "bfloat16", + } + + with open(tmp_path / "config.json", "w") as f: + json.dump(pretrained_config, f) + + # Export directory + export_dir = tmp_path / "vllm_export" + export_dir.mkdir(exist_ok=True) + + # Export for vLLM + export_mcore_gpt_to_hf( + model, + pretrained_model_name_or_path=tmp_path, + dtype=torch.bfloat16, + export_dir=str(export_dir), + export_vllm_fq_weights_qstate=True, + ) + + # check if quant_amax.pth file exists + quant_amax_file = export_dir / "quant_amax.pth" + assert quant_amax_file.exists(), f"quant_amax.pth file should be created in {export_dir}" + + # make sure hf_quant_config.json file does not exist + hf_quant_config_file = export_dir / "hf_quant_config.json" + assert not hf_quant_config_file.exists(), ( + f"hf_quant_config.json file should not be created in {export_dir}" + ) + + +@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG]) +def test_mcore_vllm_export(tmp_path, quant_cfg): + """Wrapper test function for mcore vLLM export.""" + spawn_multiprocess_job( + size=1, + job=partial(_test_mcore_vllm_export, tmp_path, quant_cfg), + backend="nccl", + )
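
For reference, a minimal usage sketch of the new HF export path, mirroring the flow in `tests/gpu/torch/export/test_vllm_fakequant_export.py`. The model path and export directory are placeholders, and the random-token forward loop stands in for a real calibration dataloader:

```python
# Sketch: quantize an HF model and export weights plus amax values for vLLM fakequant serving.
# "<hf_model_path>" and "<export_dir>" are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

model = AutoModelForCausalLM.from_pretrained("<hf_model_path>").cuda().eval()


def forward_loop(m):
    # Calibration pass; replace the random tokens with a real calibration dataloader.
    input_ids = torch.randint(0, m.config.vocab_size, (1, 128)).cuda()
    with torch.no_grad():
        m(input_ids)


model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# Saves the weights in their original dtype via save_pretrained and writes the amax values to
# <export_dir>/quant_amax.pth; no hf_quant_config.json is produced in this mode.
export_hf_checkpoint(model, export_dir="<export_dir>", export_vllm_fq_weights_qstate=True)
```

The exported directory can then be served with `vllm_serve_fakequant.py`, with `AMAX_FILE_PATH` pointing at the `quant_amax.pth` file, as described in the README change above.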