|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Calibrator that returns the MSE amax of all collected tensors.""" |
| 17 | + |
| 18 | +from collections.abc import Callable |
| 19 | + |
| 20 | +import torch |
| 21 | +import torch.nn.functional as F |
| 22 | + |
| 23 | +from .. import utils as quant_utils |
| 24 | +from .calibrator import _Calibrator |
| 25 | + |
| 26 | +__all__ = ["MseCalibrator"] |
| 27 | + |
| 28 | + |
| 29 | +class MseCalibrator(_Calibrator): |
| 30 | + """Per-tensor and per-channel MSE amax search that minimizes error between x and quantized x.""" |
| 31 | + |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + amax: torch.Tensor, |
| 35 | + axis: int | tuple | list | None = None, |
| 36 | + num_steps: int = 10, |
| 37 | + start_multiplier: float = 0.25, |
| 38 | + stop_multiplier: float = 4.0, |
| 39 | + quant_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, |
| 40 | + error_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None, |
| 41 | + ): |
| 42 | + """Initialize MSE calibrator. |
| 43 | +
|
| 44 | + Args: |
| 45 | + amax: Initial amax value (required). |
| 46 | + axis: Quantization axis. None means per-tensor quantization. |
| 47 | + num_steps: Number of amax candidates to try. |
| 48 | + start_multiplier: Starting multiplier for amax search. |
| 49 | + stop_multiplier: Ending multiplier for amax search. |
| 50 | + quant_func: Function that quantizes input tensor given an amax value. |
| 51 | + Should have signature: quant_func(x, amax) -> quantized_x. |
| 52 | + error_func: Function to compute error between x and xq. |
| 53 | + Default is F.mse_loss(x, xq, reduction='none'). |
| 54 | + """ |
| 55 | + super().__init__(num_bits=None, axis=axis, unsigned=None) |
| 56 | + self._initial_amax = amax |
| 57 | + self._num_steps = num_steps |
| 58 | + self._start_multiplier = start_multiplier |
| 59 | + self._stop_multiplier = stop_multiplier |
| 60 | + self._quant_func = quant_func |
| 61 | + self._error_func = error_func |
| 62 | + self._losses_sum = [None] * num_steps |
| 63 | + self._candidate_amaxs = [None] * num_steps |
| 64 | + |
| 65 | + self._amax = None |
| 66 | + |
| 67 | + @torch.no_grad() |
| 68 | + def collect(self, x: torch.Tensor): |
| 69 | + """Collect input tensor statistics and compute losses for MSE calibration. |
| 70 | +
|
| 71 | + Args: |
| 72 | + x: Input tensor. |
| 73 | + """ |
| 74 | + if self._quant_func is None: |
| 75 | + raise RuntimeError( |
| 76 | + "Quantization function not set. Msecalibrator requires a quant_func to be provided." |
| 77 | + ) |
| 78 | + |
| 79 | + x = x.detach().to(dtype=torch.float32) |
| 80 | + |
| 81 | + device = x.device |
| 82 | + multipliers = torch.linspace( |
| 83 | + self._start_multiplier, self._stop_multiplier, steps=self._num_steps, device=device |
| 84 | + ) |
| 85 | + |
| 86 | + # Get reduce axis for per-channel quantization |
| 87 | + reduce_axis = quant_utils.convert_quantization_axis_to_reduce_axis(x, self._axis) |
| 88 | + |
| 89 | + for step, multiplier in enumerate(multipliers): |
| 90 | + candidate_amax = self._initial_amax * multiplier |
| 91 | + xq = self._quant_func(x, candidate_amax) |
| 92 | + |
| 93 | + if self._error_func is not None: |
| 94 | + error = self._error_func(x, xq) |
| 95 | + else: |
| 96 | + error = F.mse_loss(x, xq, reduction="none") |
| 97 | + |
| 98 | + loss = quant_utils.reduce_sum(error, axis=reduce_axis, keepdims=False) |
| 99 | + |
| 100 | + if self._candidate_amaxs[step] is None: |
| 101 | + self._candidate_amaxs[step] = candidate_amax |
| 102 | + |
| 103 | + if self._losses_sum[step] is None: |
| 104 | + self._losses_sum[step] = loss.clone() |
| 105 | + else: |
| 106 | + self._losses_sum[step] += loss |
| 107 | + |
| 108 | + def reset(self): |
| 109 | + """Reset the stored losses and amax value.""" |
| 110 | + self._losses_sum = [None] * self._num_steps |
| 111 | + self._candidate_amaxs = [None] * self._num_steps |
| 112 | + self._amax = None |
| 113 | + |
| 114 | + @torch.no_grad() |
| 115 | + def compute_amax(self, verbose: bool = False): |
| 116 | + """Return the amax value that minimizes quantization error. |
| 117 | +
|
| 118 | + Args: |
| 119 | + verbose: If True, print the ratio of best_amax to initial_amax. |
| 120 | + """ |
| 121 | + if not any(loss_sum is not None for loss_sum in self._losses_sum): |
| 122 | + return None |
| 123 | + |
| 124 | + # Check if this is per-tensor or per-channel based on the first loss |
| 125 | + first_loss_sum = None |
| 126 | + for loss_sum in self._losses_sum: |
| 127 | + if loss_sum is not None: |
| 128 | + first_loss_sum = loss_sum |
| 129 | + break |
| 130 | + |
| 131 | + if first_loss_sum is None: |
| 132 | + return None |
| 133 | + |
| 134 | + # Collect losses for all steps |
| 135 | + losses_per_step = [] |
| 136 | + for step in range(self._num_steps): |
| 137 | + if self._losses_sum[step] is not None: |
| 138 | + losses_per_step.append(self._losses_sum[step]) |
| 139 | + # No data for this step, use inf |
| 140 | + elif first_loss_sum.ndim == 0: |
| 141 | + losses_per_step.append(torch.tensor(float("inf"), device=first_loss_sum.device)) |
| 142 | + else: |
| 143 | + losses_per_step.append(torch.full_like(first_loss_sum, float("inf"))) |
| 144 | + |
| 145 | + # Stack to get [num_steps] for per-tensor or [num_steps, num_channels] for per-channel |
| 146 | + losses_per_step = torch.stack(losses_per_step) |
| 147 | + |
| 148 | + # Find best step(s): scalar for per-tensor, [num_channels] for per-channel |
| 149 | + best_steps = torch.argmin(losses_per_step, dim=0) |
| 150 | + |
| 151 | + # Stack candidate amaxs and select based on best_steps |
| 152 | + candidate_amaxs = torch.stack(self._candidate_amaxs) |
| 153 | + |
| 154 | + if first_loss_sum.ndim == 0: |
| 155 | + # Per-tensor case: best_steps is a scalar |
| 156 | + self._amax = self._candidate_amaxs[best_steps.item()] |
| 157 | + else: |
| 158 | + # Per-channel case: best_steps is a tensor |
| 159 | + num_channels = best_steps.shape[0] |
| 160 | + self._amax = candidate_amaxs[ |
| 161 | + best_steps, torch.arange(num_channels, device=best_steps.device) |
| 162 | + ] |
| 163 | + self._amax = self._amax.reshape(self._initial_amax.shape) |
| 164 | + |
| 165 | + if verbose: |
| 166 | + ratio = self._amax / self._initial_amax |
| 167 | + if ratio.ndim == 0: |
| 168 | + print(f"MSE Calibrator: best_amax/initial_amax ratio = {ratio.item():.4f}") |
| 169 | + else: |
| 170 | + print( |
| 171 | + f"MSE Calibrator: best_amax/initial_amax ratio - " |
| 172 | + f"mean: {ratio.mean().item():.4f}, " |
| 173 | + f"min: {ratio.min().item():.4f}, " |
| 174 | + f"max: {ratio.max().item():.4f}" |
| 175 | + ) |
| 176 | + |
| 177 | + return self._amax |
0 commit comments