1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -17,6 +17,7 @@ Model Optimizer Changelog (Linux)
- Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
- Add support for PyTorch Geometric quantization.
- Add per tensor and per channel MSE calibrator support.
- Add support for PTQ/QAT checkpoint export and loading for running fakequant evaluation in vLLM. See ``examples/vllm_serve/README.md#load-qatptq-model-and-serve-in-vllm-wip`` for more details.

**Documentation**

15 changes: 10 additions & 5 deletions examples/vllm_serve/README.md
@@ -55,16 +55,19 @@ lm_eval --model local-completions --tasks gsm8k --model_args model=<model_name>,

## Load QAT/PTQ model and serve in vLLM (WIP)

Overwrite the calibrated amax values with values prepared from either QAT or PTQ.

Step 1: export the model with bf16 weights and amax values (a rough export sketch follows the list below).

- For an HF model, set `export_bf16_weights_amax` and export with `modelopt.torch.export.unified_export_hf.export_hf_checkpoint`.
- For an MCore model, set `export_bf16_weights_amax` and export with `modelopt.torch.export.unified_export_megatron.export_mcore_gpt_to_hf`.
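
The export call below is only a sketch: `export_dir` and passing `export_bf16_weights_amax` as a keyword argument are assumptions for illustration, not the confirmed API; check the ModelOpt export documentation for the exact interface.

```python
# Hypothetical HF export sketch; argument names may differ from the real API.
from transformers import AutoModelForCausalLM

from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# ... run PTQ calibration or QAT on `model` with modelopt.torch.quantization ...

export_hf_checkpoint(
    model,
    export_dir="llama3.1-bf16-amax",   # hypothetical output directory
    export_bf16_weights_amax=True,     # keep bf16 weights plus the calibrated amax values
)
```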

Step 2: point the `AMAX_FILE_PATH` environment variable at the `<quant_amax.pth>` file exported in Step 1 and launch the server. For example:

```bash
AMAX_FILE_PATH=<vllm_amax.pth> QUANT_CFG=<quant_config> python vllm_serve_fakequant.py <model_path> -tp 8 --host 0.0.0.0 --port 8000
```
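
Once the server is up, a quick way to sanity-check the fake-quantized model is to query the OpenAI-compatible completions endpoint that vLLM exposes (host, port, and model name below are placeholders matching the command above):

```python
# Minimal smoke test against the server started above; adjust host/port/model name.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "<model_name>",          # name of the served model
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "temperature": 0.0,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```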

## Important Notes

**Amax Synchronization across Tensor Parallel (TP):**
@@ -85,3 +88,5 @@ torch.distributed.barrier()
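
The collapsed snippet above ends with a `torch.distributed.barrier()`. As a rough illustration of the idea only (not the exact code in `vllm_serve_fakequant.py`), synchronizing amax across TP ranks boils down to an elementwise max-reduction over every quantizer `_amax` buffer before the barrier:

```python
# Illustration only: give every TP rank the same (worst-case) amax so the
# fake-quant scales agree across sharded layers.
import torch
import torch.distributed as dist


def sync_amax_across_tp(model: torch.nn.Module, tp_group=None) -> None:
    for name, buf in model.named_buffers():
        if name.endswith("_amax"):
            dist.all_reduce(buf, op=dist.ReduceOp.MAX, group=tp_group)
    dist.barrier()
```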
## Known Problems

1. AWQ is not yet supported in vLLM.
2. PTQ/QAT checkpoints do not work with KV cache quantization enabled.
> Reviewer (Contributor): Thanks Kinjal for documenting this. Create a jira ticket to address this - https://jirasw.nvidia.com/browse/OMNIML-3051
> Author (Contributor): Thank you for creating the ticket.

3. Mixed-precision checkpoints do not work currently.
213 changes: 0 additions & 213 deletions examples/vllm_serve/convert_amax_hf2vllm.py

This file was deleted.

106 changes: 105 additions & 1 deletion examples/vllm_serve/fakequant_worker.py
@@ -15,7 +15,9 @@

import dataclasses
import os
import re
import warnings
from collections import defaultdict
from contextlib import contextmanager
from typing import Any

Expand All @@ -30,6 +32,99 @@
from modelopt.torch.utils.dataset_utils import get_dataset_dataloader


def convert_amax_hf2vllm(
hf_state_dict: dict[str, torch.Tensor], fuse_experts: bool = False
) -> dict[str, torch.Tensor]:
"""
Convert amax values from HuggingFace format to vLLM format.

This function merges:
- q_proj, k_proj, v_proj amax values into qkv_proj (taking max)
- gate_proj, up_proj amax values into gate_up_proj (taking max)

Args:
hf_state_dict: HuggingFace state dict containing amax values
fuse_experts: if True, also merge per-expert gate/up and down projection amax values into the fused expert quantizer keys (w13/w2)

Returns:
vLLM format state dict with merged amax values
"""
vllm_state_dict = {}

# Group keys by their base pattern (without the specific projection name)
merge_groups = defaultdict(list)

for key, value in hf_state_dict.items():
if "_amax" not in key:
# Copy non-amax keys as-is
vllm_state_dict[key] = value
continue

# Check if this is a q/k/v projection that needs merging
qkv_match = re.search(r"(.*\.)([qkv])_proj(\..+_amax)$", key)
if qkv_match:
base_pattern = qkv_match.group(1) + "qkv_proj" + qkv_match.group(3)
merge_groups[base_pattern].append((key, value))
continue

# Check if this is an expert gate/up projection
# Pattern: model.layers.0.mlp.experts.*.gate_proj.input_quantizer._amax and
# model.layers.0.mlp.experts.*.up_proj.input_quantizer._amax
# Maps to: model.layers.0.mlp.experts.w13_input_quantizer._amax
expert_gate_up_match = (
"mixer" not in key
and fuse_experts
and re.search(r"(.*\.experts)\.\d+\.(gate|up)_proj\.([^.]+_quantizer\._amax)$", key)
)
if expert_gate_up_match:
base_pattern = expert_gate_up_match.group(1) + ".w13_" + expert_gate_up_match.group(3)
merge_groups[base_pattern].append((key, value))
continue

# Check if this is a non-expert gate/up projection that needs merging
gate_up_match = (
"mixer" not in key
and "experts" not in key
and re.search(r"(.*\.)(gate|up)_proj(\..+_amax)$", key)
)
if gate_up_match:
base_pattern = gate_up_match.group(1) + "gate_up_proj" + gate_up_match.group(3)
merge_groups[base_pattern].append((key, value))
continue

# Check if this is an expert down_proj
# Pattern: model.layers.0.mlp.experts.*.down_proj.input_quantizer._amax
# Maps to: model.layers.0.mlp.experts.w2_input_quantizer._amax
expert_down_match = (
"mixer" not in key
and fuse_experts
and re.search(r"(.*\.experts)\.\d+\.down_proj\.([^.]+_quantizer\._amax)$", key)
)
if expert_down_match:
base_pattern = expert_down_match.group(1) + ".w2_" + expert_down_match.group(2)
merge_groups[base_pattern].append((key, value))
continue

# Copy other amax keys as-is (like o_proj, down_proj)
vllm_state_dict[key] = value

# Merge grouped amax values by taking the maximum
for merged_key, key_value_pairs in merge_groups.items():
if len(key_value_pairs) > 1:
# Take the maximum across all values for this merged key
values = [value for _, value in key_value_pairs]
merged_value = torch.stack(values).max(dim=0)[0]
vllm_state_dict[merged_key] = merged_value
print(f"Merged {len(key_value_pairs)} keys into {merged_key}")
for orig_key, _ in key_value_pairs:
print(f" - {orig_key}")
else:
# Single key, just rename it
_, value = key_value_pairs[0]
vllm_state_dict[merged_key] = value

return vllm_state_dict
> Reviewer (Collaborator): what if we miss a merging rule here? What will happen in VLLM?
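
# Illustrative usage (not called anywhere in the module): shows how the merging
# rules above rewrite HF-style amax keys into vLLM-style fused keys. Layer names
# and values below are made up for the example.
def _demo_convert_amax_hf2vllm() -> None:
    hf_amax = {
        "model.layers.0.self_attn.q_proj.input_quantizer._amax": torch.tensor(1.0),
        "model.layers.0.self_attn.k_proj.input_quantizer._amax": torch.tensor(2.0),
        "model.layers.0.self_attn.v_proj.input_quantizer._amax": torch.tensor(3.0),
        "model.layers.0.mlp.gate_proj.input_quantizer._amax": torch.tensor(0.5),
        "model.layers.0.mlp.up_proj.input_quantizer._amax": torch.tensor(1.5),
    }
    vllm_amax = convert_amax_hf2vllm(hf_amax)
    # q/k/v merge into qkv_proj and gate/up merge into gate_up_proj, each taking
    # the elementwise max of the contributing amax tensors.
    assert vllm_amax["model.layers.0.self_attn.qkv_proj.input_quantizer._amax"] == 3.0
    assert vllm_amax["model.layers.0.mlp.gate_up_proj.input_quantizer._amax"] == 1.5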



@contextmanager
def disable_compilation(model):
do_not_compile = True
@@ -154,8 +249,17 @@ def calibrate_loop(model: Any = None) -> None:
if amax_file_path:
print(f"Loading amax values from {amax_file_path}")
saved_amax_dict = torch.load(amax_file_path)
# convert amax keys to vLLM format
if hasattr(self.model_runner.model, "hf_to_vllm_mapper"):
saved_amax_dict = self.model_runner.model.hf_to_vllm_mapper.apply_dict(saved_amax_dict)
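# The exported checkpoint stores amax tensors under flattened names ending in
# 'quantizer_amax'; restore the '._amax' buffer suffix expected by the quantizer
# modules and drop every non-amax entry.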
saved_amax_dict = {
key.replace("quantizer_amax", "quantizer._amax"): value
for key, value in saved_amax_dict.items()
if key.endswith("quantizer_amax")
}
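# Merge per-projection amax into the fused vLLM modules: q/k/v -> qkv_proj,
# gate/up -> gate_up_proj, and (with fuse_experts=True) per-expert projections
# into the fused w13/w2 expert quantizers.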
saved_amax_dict = convert_amax_hf2vllm(saved_amax_dict, fuse_experts=True)

current_state_dict = model.state_dict()
# Count amax keys in checkpoint and model
checkpoint_amax_keys = [key for key in saved_amax_dict if key.endswith("_amax")]
model_amax_keys = [key for key in current_state_dict if key.endswith("_amax")]