From 18e7e92939bf0f5f631bd729a498ad1666799fe2 Mon Sep 17 00:00:00 2001
From: weimingc <17592131+meenchen@users.noreply.github.com>
Date: Thu, 6 Nov 2025 22:08:39 +0000
Subject: [PATCH 1/4] fix

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
---
 modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index bf801646b..4c8440f3d 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -558,7 +558,9 @@ def _real_quantize(self, inputs):
                 inputs,
                 axis=self._axis,
                 block_sizes=self._block_sizes,
-                scales=self.amax / 448.0 if self.amax is not None else None,
+                scales=self.amax / 448.0
+                if (self.amax is not None and not self._block_sizes)
+                else None,
             )
             buffer_to_register["_scale"] = _scale
         elif self._num_bits == 8:

From 17dcf35ac9569c21adeca739e9b894add07085d0 Mon Sep 17 00:00:00 2001
From: weimingc <17592131+meenchen@users.noreply.github.com>
Date: Thu, 6 Nov 2025 22:21:20 +0000
Subject: [PATCH 2/4] add test

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
---
 .../torch/quantization/test_qtensor_cuda.py | 33 +++++++++++++++++++
 1 file changed, 33 insertions(+)

diff --git a/tests/gpu/torch/quantization/test_qtensor_cuda.py b/tests/gpu/torch/quantization/test_qtensor_cuda.py
index f1e511c21..a3710ea0c 100644
--- a/tests/gpu/torch/quantization/test_qtensor_cuda.py
+++ b/tests/gpu/torch/quantization/test_qtensor_cuda.py
@@ -569,3 +569,36 @@ def test_nvfp4_dequantize_fast(self, shape, input_dtype):
             f"Fast and standard dequantization differ: "
             f"max diff = {(dequant_fast - dequant_standard).abs().max()}"
         )
+
+    @pytest.mark.parametrize("device", ["cuda"])
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float16, torch.bfloat16])
+    @pytest.mark.parametrize(
+        ("input_shape", "block_sizes"),
+        [
+            ((128, 1152), {-1: 128}),
+            ((256, 256), {-1: 64, -2: 64}),  # 2D block sizes
+        ],
+    )
+    def test_fp8_with_amax_and_block_sizes(self, device, input_dtype, input_shape, block_sizes):
+        """Test FP8 quantization with both amax and block_sizes specified."""
+        quant_cfg = QuantizerAttributeConfig(
+            num_bits=(4, 3),
+            block_sizes=block_sizes,
+            fake_quant=False,
+        )
+        quantizer = TensorQuantizer(quant_cfg).to(device)
+
+        # Set a mock amax (scalar) - this was causing the bug
+        mock_amax = torch.tensor(1.5, device=device)
+        quantizer.amax = mock_amax
+
+        # Create input tensor
+        x = torch.randn(input_shape, dtype=input_dtype, device=device)
+
+        # QDQ
+        q_x = quantizer(x)
+        deq_x = quantizer(q_x)
+
+        assert torch.allclose(deq_x, x, rtol=1e-1, atol=1e-1)
+        assert hasattr(quantizer, "_scale")
+        assert quantizer._scale.numel() > 1

From b0564cdf8373013c773f2cb08b829544e30c767d Mon Sep 17 00:00:00 2001
From: weimingc <17592131+meenchen@users.noreply.github.com>
Date: Tue, 2 Dec 2025 21:23:34 +0000
Subject: [PATCH 3/4] address comment

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
---
 modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 4c8440f3d..9ceeac0c2 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -559,6 +559,7 @@ def _real_quantize(self, inputs):
                 axis=self._axis,
                 block_sizes=self._block_sizes,
                 scales=self.amax / 448.0
+                # for blockwise quantization, amax is a scalar and will be recomputed in the kernel
                 if (self.amax is not None and not self._block_sizes)
                 else None,
             )

From 86546cda76d2b0f0444dc13d76883adcbd7f7c59 Mon Sep 17 00:00:00 2001
From: weimingc <17592131+meenchen@users.noreply.github.com>
Date: Tue, 2 Dec 2025 22:23:15 +0000
Subject: [PATCH 4/4] address comment

Signed-off-by: weimingc <17592131+meenchen@users.noreply.github.com>
---
 modelopt/torch/quantization/nn/modules/tensor_quantizer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
index 9ceeac0c2..c1e62fc3b 100644
--- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
+++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py
@@ -554,14 +554,13 @@ def _real_quantize(self, inputs):
         if self._num_bits == (4, 3):
             # FP8 quantization
             # For per-tensor/per-channel quantization, we might need amax which is synced across all ranks
+            # For blockwise quantization, amax will be recomputed in the kernel
+            use_amax = self.amax is not None and not (self._block_sizes and self.amax.numel() == 1)
             outputs, _scale = FP8QTensor.quantize(
                 inputs,
                 axis=self._axis,
                 block_sizes=self._block_sizes,
-                scales=self.amax / 448.0
-                # for blockwise quantization, amax is a scalar and will be recomputed in the kernel
-                if (self.amax is not None and not self._block_sizes)
-                else None,
+                scales=self.amax / 448.0 if use_amax else None,
             )
             buffer_to_register["_scale"] = _scale
         elif self._num_bits == 8:
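
Note on the guard introduced by PATCH 4/4: it replaces the inline conditional from patches 1 and 3 with a `use_amax` flag, so a scalar amax is only used as a precomputed scale source for per-tensor/per-channel FP8. Below is a minimal standalone sketch of that logic; `should_use_amax` is a hypothetical helper that mirrors the patched expression and does not exist in modelopt, and 448.0 is the largest finite FP8 E4M3 value.

    import torch

    def should_use_amax(amax, block_sizes) -> bool:
        # Per-tensor/per-channel FP8: a calibrated amax converts directly
        # into a scale (amax / 448.0). Blockwise FP8 with a *scalar* amax
        # cannot: the kernel needs one scale per block, so it recomputes them.
        return amax is not None and not (block_sizes and amax.numel() == 1)

    assert should_use_amax(torch.tensor(1.5), None)           # per-tensor: use amax
    assert not should_use_amax(torch.tensor(1.5), {-1: 128})  # blockwise + scalar: recompute
    assert should_use_amax(torch.ones(128, 9), {-1: 128})     # blockwise, per-block amax: usable
    assert not should_use_amax(None, {-1: 128})               # no amax: nothing to pass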
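
Relatedly, the test in PATCH 2/4 asserts `quantizer._scale.numel() > 1`, i.e. that blockwise quantization produced one scale per block rather than echoing the scalar amax. The sketch below is a reference computation of per-block E4M3 scales along the last dimension; `blockwise_scales` is illustrative only and is not the FP8QTensor kernel, which may compute scales differently.

    import torch

    def blockwise_scales(x: torch.Tensor, block: int) -> torch.Tensor:
        # One amax, hence one scale, per contiguous block of `block` elements
        # along the last dimension (i.e. block_sizes = {-1: block}).
        blocks = x.reshape(*x.shape[:-1], x.shape[-1] // block, block)
        amax = blocks.abs().amax(dim=-1)  # shape (..., n_blocks)
        return amax / 448.0               # 448 = max finite FP8 E4M3 value

    x = torch.randn(128, 1152)
    scales = blockwise_scales(x, 128)     # shape (128, 9): one scale per block
    assert scales.numel() > 1             # mirrors the test's expectation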