Commit 404194a

Merge branch 'main' into shengliangx/fix-extra-args

2 parents: 03a9148 + a5025a2
File tree: 12 files changed, +1005 / -30 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ Model Optimizer Changelog (Linux)
 - Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
 - Add flag ``trt_plugins_precision`` in ONNX autocast to indicate custom ops precision. This is similar to the flag already existing in the quantization workflow.
 - Add support for PyTorch Geometric quantization.
+- Add per-tensor and per-channel MSE calibrator support.

 **Documentation**

modelopt/torch/export/quant_utils.py

Lines changed: 30 additions & 1 deletion

@@ -779,7 +779,36 @@ def to_quantized_weight(
         )[0]._quantized_data

     if quantization == QUANTIZATION_FP8_PC_PT:
-        return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
+        if weight.dim() == 3:
+            # Handle different scale tensor shapes
+            if weights_scaling_factor.dim() == 1:
+                # Per-expert scaling only: (num_experts,) -> (num_experts, 1, 1)
+                return (weight / weights_scaling_factor[:, None, None]).to(torch.float8_e4m3fn)
+            elif weights_scaling_factor.dim() == 2:
+                # Per-channel scaling: check which dimension matches
+                if weights_scaling_factor.shape[0] != weight.shape[0]:
+                    raise ValueError(
+                        f"First dimension (num_experts) mismatch for FP8_PC_PT quantization. "
+                        f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
+                    )
+                if weight.shape[-1] == weight.shape[-2]:
+                    raise ValueError(
+                        f"Ambiguous scaling dimension for FP8_PC_PT quantization with square weight matrix. "
+                        f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}. "
+                        f"Cannot determine if scaling should be applied to input_dim or output_dim."
+                    )
+                if weights_scaling_factor.shape[-1] == weight.shape[-1]:
+                    # (num_experts, input_dim) -> (num_experts, 1, input_dim), BMM-style
+                    return (weight / weights_scaling_factor.unsqueeze(-2)).to(torch.float8_e4m3fn)
+                elif weights_scaling_factor.shape[-1] == weight.shape[-2]:
+                    # (num_experts, output_dim) -> (num_experts, output_dim, 1), standard MoE case
+                    return (weight / weights_scaling_factor.unsqueeze(-1)).to(torch.float8_e4m3fn)
+                else:
+                    raise ValueError(
+                        f"Cannot determine correct unsqueeze dimension for FP8_PC_PT quantization. "
+                        f"weight shape: {weight.shape}, scale shape: {weights_scaling_factor.shape}"
+                    )
+        return (weight / weights_scaling_factor[:, None]).to(torch.float8_e4m3fn)

     if quantization in [QUANTIZATION_INT4_AWQ, QUANTIZATION_W4A8_AWQ]:
         return pack_int4_in_uint8(weight, weights_scaling_factor)
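For intuition, here is a minimal sketch (made-up shapes, not part of the commit) of the three broadcasting branches the new FP8_PC_PT code distinguishes for a 3-D expert weight. The square-matrix guard above exists precisely because the per-input-channel and per-output-channel cases below would be indistinguishable when output_dim == input_dim:

import torch

E, O, I = 4, 128, 256                  # num_experts, output_dim, input_dim
weight = torch.randn(E, O, I)

s_expert = torch.rand(E) + 0.5         # per-expert scalar scale
s_in = torch.rand(E, I) + 0.5          # per input channel (BMM-style)
s_out = torch.rand(E, O) + 0.5         # per output channel (standard MoE)

q1 = weight / s_expert[:, None, None]  # (E, 1, 1) broadcasts over both matrix dims
q2 = weight / s_in.unsqueeze(-2)       # (E, 1, I) matches weight.shape[-1]
q3 = weight / s_out.unsqueeze(-1)      # (E, O, 1) matches weight.shape[-2]

assert q1.shape == q2.shape == q3.shape == weight.shape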

modelopt/torch/export/unified_export_hf.py

Lines changed: 25 additions & 4 deletions

@@ -50,6 +50,7 @@
     KV_CACHE_NVFP4_AFFINE,
     QUANTIZATION_FP8,
     QUANTIZATION_FP8_PB_REAL,
+    QUANTIZATION_FP8_PC_PT,
     QUANTIZATION_NONE,
     QUANTIZATION_NVFP4,
     QUANTIZATION_NVFP4_AWQ,
@@ -327,13 +328,15 @@ def _export_quantized_weight(
     weight_scale_2: torch.Tensor | None = getattr(sub_module, quantizer_attrs.weight_scale_2, None)

     # Transpose weight for bmm-style expert quantization (llama4, gpt-oss)
+    # Check if this is a BMM-style expert weight that needs transposition
+    is_bmm_expert_weight = weight.dim() == 3 and any(
+        expert_type in type(sub_module).__name__
+        for expert_type in ["Llama4TextExperts", "GptOssExperts"]
+    )
+
     if quantization_format in [QUANTIZATION_NVFP4, QUANTIZATION_NVFP4_AWQ]:
         # Transpose weight from (num_experts, input_dim, output_dim) to (num_experts, output_dim, input_dim)
         # for NVFP4 quantization functions that expect input_dim as the last dimension for block quantization
-        is_bmm_expert_weight = weight.dim() == 3 and any(
-            expert_type in type(sub_module).__name__
-            for expert_type in ["Llama4TextExperts", "GptOssExperts"]
-        )
         weight, _ = maybe_transpose_expert_weight_dimensions(
             weight, is_bmm_expert_weight=is_bmm_expert_weight
         )
@@ -354,6 +357,24 @@ def _export_quantized_weight(
         quantized_weight, weight_scale = maybe_transpose_expert_weight_dimensions(
             quantized_weight, weight_scale, is_bmm_expert_weight=is_bmm_expert_weight
         )
+    elif quantization_format == QUANTIZATION_FP8_PC_PT and is_bmm_expert_weight:
+        # For FP8_PC_PT with BMM-style experts, transpose only the weight (not weight_scale)
+        weight, _ = maybe_transpose_expert_weight_dimensions(
+            weight, is_bmm_expert_weight=is_bmm_expert_weight
+        )
+
+        quantized_weight = to_quantized_weight(
+            weight.to(dtype),
+            weight_scale,
+            quantization_format,
+            weight_scale_2,
+            block_size,
+        )
+
+        # Transpose back to original BMM format
+        quantized_weight, _ = maybe_transpose_expert_weight_dimensions(
+            quantized_weight, is_bmm_expert_weight=is_bmm_expert_weight
+        )
     else:
         quantized_weight = to_quantized_weight(
             weight.to(dtype),
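To see why the scale is not transposed along with the weight, here is a minimal sketch with made-up shapes; transpose_bmm is a stand-in for maybe_transpose_expert_weight_dimensions, and the real path additionally casts the result to torch.float8_e4m3fn:

import torch

def transpose_bmm(w: torch.Tensor) -> torch.Tensor:
    # Swap (num_experts, input_dim, output_dim) <-> (num_experts, output_dim, input_dim)
    return w.transpose(-2, -1).contiguous()

E, I, O = 4, 256, 128
weight = torch.randn(E, I, O)           # BMM-style expert storage
weight_scale = torch.rand(E, O) + 0.5   # one scale per output channel per expert

w_t = transpose_bmm(weight)             # (E, O, I): scale dim O now sits at dim -2
q_t = w_t / weight_scale.unsqueeze(-1)  # hits the unsqueeze(-1) branch in to_quantized_weight
quantized = transpose_bmm(q_t)          # back to the original (E, I, O) layout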

modelopt/torch/quantization/calib/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -23,3 +23,4 @@
 from .calibrator import *
 from .histogram import *
 from .max import *
+from .mse import *
modelopt/torch/quantization/calib/mse.py (new file)

Lines changed: 177 additions & 0 deletions

@@ -0,0 +1,177 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Calibrator that returns the MSE amax of all collected tensors."""
+
+from collections.abc import Callable
+
+import torch
+import torch.nn.functional as F
+
+from .. import utils as quant_utils
+from .calibrator import _Calibrator
+
+__all__ = ["MseCalibrator"]
+
+
+class MseCalibrator(_Calibrator):
+    """Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x."""
+
+    def __init__(
+        self,
+        amax: torch.Tensor,
+        axis: int | tuple | list | None = None,
+        num_steps: int = 10,
+        start_multiplier: float = 0.25,
+        stop_multiplier: float = 4.0,
+        quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+        error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None,
+    ):
+        """Initialize MSE calibrator.
+
+        Args:
+            amax: Initial amax value (required).
+            axis: Quantization axis. None means per-tensor quantization.
+            num_steps: Number of amax candidates to try.
+            start_multiplier: Starting multiplier for amax search.
+            stop_multiplier: Ending multiplier for amax search.
+            quant_func: Function that quantizes input tensor given an amax value.
+                Should have signature: quant_func(x, amax) -> quantized_x.
+            error_func: Function to compute error between x and xq.
+                Default is F.mse_loss(x, xq, reduction='none').
+        """
+        super().__init__(num_bits=None, axis=axis, unsigned=None)
+        self._initial_amax = amax
+        self._num_steps = num_steps
+        self._start_multiplier = start_multiplier
+        self._stop_multiplier = stop_multiplier
+        self._quant_func = quant_func
+        self._error_func = error_func
+        self._losses_sum = [None] * num_steps
+        self._candidate_amaxs = [None] * num_steps
+
+        self._amax = None
+
+    @torch.no_grad()
+    def collect(self, x: torch.Tensor):
+        """Collect input tensor statistics and compute losses for MSE calibration.
+
+        Args:
+            x: Input tensor.
+        """
+        if self._quant_func is None:
+            raise RuntimeError(
+                "Quantization function not set. MseCalibrator requires a quant_func to be provided."
+            )
+
+        x = x.detach().to(dtype=torch.float32)
+
+        device = x.device
+        multipliers = torch.linspace(
+            self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device
+        )
+
+        # Get reduce axis for per-channel quantization
+        reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis)
+
+        for step, multiplier in enumerate(multipliers):
+            candidate_amax = self._initial_amax * multiplier
+            xq = self._quant_func(x, candidate_amax)
+
+            if self._error_func is not None:
+                error = self._error_func(x, xq)
+            else:
+                error = F.mse_loss(x, xq, reduction="none")
+
+            loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False)
+
+            if self._candidate_amaxs[step] is None:
+                self._candidate_amaxs[step] = candidate_amax
+
+            if self._losses_sum[step] is None:
+                self._losses_sum[step] = loss.clone()
+            else:
+                self._losses_sum[step] += loss
+
+    def reset(self):
+        """Reset the stored losses and amax value."""
+        self._losses_sum = [None] * self._num_steps
+        self._candidate_amaxs = [None] * self._num_steps
+        self._amax = None
+
+    @torch.no_grad()
+    def compute_amax(self, verbose: bool = False):
+        """Return the amax value that minimizes quantization error.
+
+        Args:
+            verbose: If True, print the ratio of best_amax to initial_amax.
+        """
+        if not any(loss_sum is not None for loss_sum in self._losses_sum):
+            return None
+
+        # Check if this is per-tensor or per-channel based on the first loss
+        first_loss_sum = None
+        for loss_sum in self._losses_sum:
+            if loss_sum is not None:
+                first_loss_sum = loss_sum
+                break
+
+        if first_loss_sum is None:
+            return None
+
+        # Collect losses for all steps
+        losses_per_step = []
+        for step in range(self._num_steps):
+            if self._losses_sum[step] is not None:
+                losses_per_step.append(self._losses_sum[step])
+            # No data for this step, use inf
+            elif first_loss_sum.ndim == 0:
+                losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device))
+            else:
+                losses_per_step.append(torch.full_like(first_loss_sum, float("inf")))
+
+        # Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel
+        losses_per_step = torch.stack(losses_per_step)

+        # Find best step(s): scalar for per-tensor, [num_channels] for per-channel
+        best_steps = torch.argmin(losses_per_step, dim=0)
+
+        # Stack candidate amaxs and select based on best_steps
+        candidate_amaxs = torch.stack(self._candidate_amaxs)
+
+        if first_loss_sum.ndim == 0:
+            # Per-tensor case: best_steps is a scalar
+            self._amax = self._candidate_amaxs[best_steps.item()]
+        else:
+            # Per-channel case: best_steps is a tensor
+            num_channels = best_steps.shape[0]
+            self._amax = candidate_amaxs[
+                best_steps, torch.arange(num_channels, device=best_steps.device)
+            ]
+            self._amax = self._amax.reshape(self._initial_amax.shape)
+
+        if verbose:
+            ratio = self._amax / self._initial_amax
+            if ratio.ndim == 0:
+                print(f"MSE Calibrator: best_amax/initial_amax ratio = {ratio.item():.4f}")
+            else:
+                print(
+                    f"MSE Calibrator: best_amax/initial_amax ratio - "
+                    f"mean: {ratio.mean().item():.4f}, "
+                    f"min: {ratio.min().item():.4f}, "
+                    f"max: {ratio.max().item():.4f}"
+                )
+
+        return self._amax
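A minimal usage sketch for the new calibrator; fake_quant below is a stand-in quant_func (the library wires in its own quantization function during calibration), and the shapes are made up. For a per-channel search, pass axis together with a per-channel initial amax:

import torch

from modelopt.torch.quantization.calib import MseCalibrator

def fake_quant(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # Symmetric 8-bit fake quantization given an amax.
    scale = amax / 127.0
    return (x / scale).round().clamp(-128, 127) * scale

x = torch.randn(512, 256)
calibrator = MseCalibrator(amax=x.abs().amax(), quant_func=fake_quant)
calibrator.collect(x)  # call once per calibration batch
best_amax = calibrator.compute_amax(verbose=True)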

modelopt/torch/quantization/config.py

Lines changed: 39 additions & 0 deletions

@@ -981,6 +981,45 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
     )


+class MseCalibConfig(QuantizeAlgorithmConfig):
+    """Configuration for per-tensor and per-channel MSE calibration.
+
+    Finds a scale s (via amax a, with s = a / q_max) that minimizes the
+    reconstruction error of a tensor after uniform Q→DQ:
+
+        s* = argmin_s E[(X - DQ(Q(X; s)))^2], X ∈ {weights | activations}
+    """
+
+    method: Literal["mse"] = ModeloptField("mse")
+
+    num_steps: int | None = ModeloptField(
+        default=10,
+        ge=1,
+        title="Number of amax candidates to try.",
+        description="Number of amax candidates to search over for MSE minimization.",
+    )
+
+    start_multiplier: float | None = ModeloptField(
+        default=0.25,
+        gt=0.0,
+        title="Starting multiplier for amax search.",
+        description="Starting multiplier for amax search range (multiplies initial amax).",
+    )
+
+    stop_multiplier: float | None = ModeloptField(
+        default=4.0,
+        gt=0.0,
+        title="Ending multiplier for amax search.",
+        description="Ending multiplier for amax search range (multiplies initial amax).",
+    )
+
+    distributed_sync: bool | None = ModeloptField(
+        default=True,
+        title="Whether to sync the amax across the distributed processes.",
+        description="If True, the amax will be synced across the distributed processes.",
+    )
+
+
 class SmoothQuantCalibConfig(QuantizeAlgorithmConfig):
     """The config for ``smoothquant`` algorithm (SmoothQuant).
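A hypothetical end-to-end sketch of how this config would be selected, assuming the new method plugs into the existing "algorithm" field of a quantize config the same way "max" and "smoothquant" do (the model and calibration loop are placeholders):

import torch
import torch.nn as nn

import modelopt.torch.quantization as mtq

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

def calibration_loop(m: nn.Module) -> None:
    # Feed a few batches so the calibrator can accumulate per-step losses.
    for _ in range(8):
        m(torch.randn(32, 64))

config = dict(mtq.FP8_DEFAULT_CFG)
config["algorithm"] = {
    "method": "mse",          # routes to MseCalibConfig / mse_calibrate via the mode registry
    "num_steps": 10,          # candidates swept from 0.25x to 4.0x of the initial amax
    "start_multiplier": 0.25,
    "stop_multiplier": 4.0,
}

model = mtq.quantize(model, config, forward_loop=calibration_loop)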

modelopt/torch/quantization/mode.py

Lines changed: 14 additions & 1 deletion

@@ -38,6 +38,7 @@
     AWQLiteCalibConfig,
     CompressConfig,
     MaxCalibConfig,
+    MseCalibConfig,
     QuantizeAlgoCfgType,
     QuantizeAlgorithmConfig,
     QuantizeConfig,
@@ -54,7 +55,7 @@
     restore_svdquant_model,
     update_quantize_metadata,
 )
-from .model_calib import awq, max_calibrate, smoothquant, svdquant
+from .model_calib import awq, max_calibrate, mse_calibrate, smoothquant, svdquant

 __all__ = ["BaseCalibrateModeDescriptor"]

@@ -363,6 +364,18 @@ def config_class(self) -> type[QuantizeAlgorithmConfig]:
     _calib_func = max_calibrate


+@CalibrateModeRegistry.register_mode
+class MseCalibrateModeDescriptor(BaseCalibrateModeDescriptor):
+    """Mode for mse calibration algorithm."""
+
+    @property
+    def config_class(self) -> type[QuantizeAlgorithmConfig]:
+        """Specifies the config class for the mode."""
+        return MseCalibConfig
+
+    _calib_func = mse_calibrate
+
+
 @CalibrateModeRegistry.register_mode
 class SmoothQuantModeDescriptor(BaseCalibrateModeDescriptor):
     """Mode for smoothquant calibration algorithm."""
