From 06b211ce0225ad41148809cb78b385fc8c47c244 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Thu, 30 Oct 2025 05:56:32 -0700
Subject: [PATCH 1/5] Assign module to device after quantization

Summary:

Previously we moved the module to the device and then quantized it; now we
quantize first and move to the device afterwards. The old ordering was
causing some of the Hugging Face integration tests to fail.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 torchao/quantization/quant_api.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 417150229c..83000bfe39 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -487,7 +487,9 @@ def quantize_(
                 module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
             )
             # this replaces inplace, so no need to reassign
-            _fqn_to_config_handler(module, module_name, config, device)
+            _fqn_to_config_handler(module, module_name, config)
+            if device is not None:
+                module.to(device=device)
         return
     if isinstance(config, AOBaseConfig):
         filter_fn = _is_linear if filter_fn is None else filter_fn
@@ -2451,7 +2453,6 @@ def _fqn_to_config_handler(
     module: torch.nn.Module,
     fqn: str,
     config: FqnToConfig,
-    device: Optional[torch.device] = None,
 ):
     """This function expects a module that either is specified in FqnToConfig or
     has a parameter that is specified in FqnToConfig.
@@ -2460,7 +2461,6 @@ def _fqn_to_config_handler(
         fqn (str): The fully qualified name of the module containing the parameters.
         config (FqnToConfig): Configuration object containing regex patterns / fqn
             mapped to quantization configurations.
-        device (Optional[torch.device]): The device to move the module to as part of quantization
 
     Returns:
         torch.nn.Module: The modified module with quantized parameters.
@@ -2468,9 +2468,6 @@ def _fqn_to_config_handler(
     Raises:
         NotImplementedError: If the quantization configuration is not yet supported for parameter quantization.
     """
-    if device is not None:
-        module = module.to(device)
-
     parameter_config_found = False
     top_level_params = []
     for i, (parameter_name, param) in enumerate(list(module.named_parameters())):
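
A minimal sketch of the call pattern this patch reorders, assuming `quantize_`
exposes the `device` keyword used in the hunk above; the model and config names
below are illustrative, not taken from this patch:

    import torch.nn as nn
    from torchao.quantization import quantize_, Float8WeightOnlyConfig
    from torchao.quantization.quant_api import FqnToConfig

    model = nn.Sequential(nn.Linear(128, 256), nn.Linear(256, 128))
    config = FqnToConfig({"0.weight": Float8WeightOnlyConfig()})

    # With this patch, the weights named in the config are quantized while the
    # module is still on its original device, and only then is the (smaller)
    # quantized module moved to the target device.
    quantize_(model, config, device="cuda")
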
""" - if device is not None: - module = module.to(device) - parameter_config_found = False top_level_params = [] for i, (parameter_name, param) in enumerate(list(module.named_parameters())): From 90502df975971ac4cda954af2e2ba3994121d650 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Sun, 2 Nov 2025 20:43:51 -0800 Subject: [PATCH 2/5] update --- torchao/quantization/__init__.py | 3 +++ torchao/quantization/quant_api.py | 33 ++++++++++++++++++++----------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 7459b2504c..ba7f38facd 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -69,6 +69,7 @@ float8_static_activation_float8_weight, float8_weight_only, fpx_weight_only, + fqn_matches_fqn_config, gemlite_uintx_weight_only, int4_dynamic_activation_int4_weight, int4_weight_only, @@ -221,4 +222,6 @@ "Int4WeightOnlyGPTQQuantizer", "MultiTensor", "MultiTensorInputRecorder", + # helper functions + "fqn_matches_fqn_config", ] diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 83000bfe39..e3096c79f8 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -160,6 +160,7 @@ "Int8DynActInt4WeightQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", + "FqnToConfig", ] LAYOUT_TO_ZERO_POINT_DOMAIN = { @@ -479,7 +480,7 @@ def quantize_( for module_fqn, module in model.named_modules(): if ( - _fqn_matches_fqn_config(module_fqn, config) + fqn_matches_fqn_config(module_fqn, config) or _module_param_matches_fqn_config(module, module_fqn, config) or ("_default" in config.fqn_to_config and _is_linear(module)) ): @@ -1254,17 +1255,22 @@ def _int4_weight_only_quantize_tensor(weight, config): @register_quantize_module_handler(Int4WeightOnlyConfig) def _int4_weight_only_transform( - module: torch.nn.Module, config: Int4WeightOnlyConfig + module: torch.nn.Module, + config: Int4WeightOnlyConfig, + *, + parameter_name: str = "weight", ) -> torch.nn.Module: if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, "weight"), ( - "applying int8 weight only quant requires module to have weight attribute" + assert hasattr(module, parameter_name), ( + "applying int8 weight only quant requires module to have {parameter_name} attribute" + " but {module} does not have one" ) - new_weight = _int4_weight_only_quantize_tensor(module.weight, config) - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + new_weight = _int4_weight_only_quantize_tensor( + getattr(module, parameter_name), config + ) + setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False)) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -2298,18 +2304,19 @@ def _intx_weight_only_transform( *, custom_scale: Optional[torch.Tensor] = None, custom_zero_point: Optional[torch.Tensor] = None, + parameter_name="weight", ) -> torch.nn.Module: - assert hasattr(module, "weight"), ( - "applying intx weight only quant requires module to have weight attribute" + assert hasattr(module, parameter_name), ( + "applying intx weight only quant requires module to have {parameter_name} attribute" + " but {module} does not have one" ) new_weight = _intx_weight_only_quantize_tensor( - module.weight, + getattr(module, parameter_name), config, custom_scale=custom_scale, custom_zero_point=custom_zero_point, ) - module.weight = 
From 3537a222731cf9f8a453d2a58e4b681b66bbd0af Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Sun, 2 Nov 2025 21:25:02 -0800
Subject: [PATCH 3/5] add test

---
 test/quantization/test_quant_api.py | 35 +++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index e1c6471b17..909ddd1842 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -1122,6 +1122,41 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    from torchao.quantization.quant_api import (
+        CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS,
+    )
+
+    @common_utils.parametrize("config", CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS)
+    def test_fqn_to_config_supported_param_configs(self, config):
+        """Test that all supported parameter configs can be applied via FqnToConfig."""
+
+        from torchao.utils import (
+            TorchAOBaseTensor,
+        )
+
+        torchao_tensor_types = (TorchAOBaseTensor, AffineQuantizedTensor)
+        m = ToyLinearModel(m=128, k=128, n=128)
+        m.linear1.register_parameter(
+            "custom_param_name", torch.nn.Parameter(torch.randn(m.linear1.weight.shape))
+        )
+        m = m.cuda().bfloat16()
+
+        fqn_config = FqnToConfig(
+            {
+                "linear1.custom_param_name": config(),
+                "linear1.weight": config(),
+                "linear2.weight": config(),
+            }
+        )
+
+        quantize_(m, fqn_config, filter_fn=None)
+
+        assert isinstance(m.linear1.custom_param_name.data, torchao_tensor_types)
+        assert isinstance(m.linear1.weight.data, torchao_tensor_types)
+        assert isinstance(m.linear2.weight.data, torchao_tensor_types)
+
+
+common_utils.instantiate_parametrized_tests(TestFqnToConfig)
 
 if __name__ == "__main__":
     unittest.main()
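
Outside the test suite, the path exercised above looks roughly like this; a
sketch assuming a single `nn.Linear` with an extra LoRA-style parameter (all
names illustrative):

    import torch
    from torchao.quantization import quantize_, Int4WeightOnlyConfig
    from torchao.quantization.quant_api import FqnToConfig

    linear = torch.nn.Linear(128, 128, dtype=torch.bfloat16, device="cuda")
    linear.register_parameter(
        "lora_A", torch.nn.Parameter(torch.randn_like(linear.weight))
    )

    # Both the built-in "weight" and the custom "lora_A" parameter are targeted
    # by fqn; quantizing the root module directly, the parameter name is the fqn.
    quantize_(
        linear,
        FqnToConfig(
            {
                "weight": Int4WeightOnlyConfig(),
                "lora_A": Int4WeightOnlyConfig(),
            }
        ),
        filter_fn=None,
    )
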
From ff68bb37d91a1147b1e6607110026d276dc8c83d Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Mon, 3 Nov 2025 10:08:14 -0800
Subject: [PATCH 4/5] add some bug fixes

---
 torchao/dtypes/affine_quantized_tensor_ops.py    |  3 +++
 torchao/quantization/quant_api.py                |  2 +-
 .../quantize_/workflows/float8/float8_tensor.py  | 15 ++++++++-------
 torchao/utils.py                                 |  3 +++
 4 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py
index ffadece729..e033e9d8b3 100644
--- a/torchao/dtypes/affine_quantized_tensor_ops.py
+++ b/torchao/dtypes/affine_quantized_tensor_ops.py
@@ -456,6 +456,9 @@ def _(func, types, args, kwargs):
 def _(func, types, args, kwargs):
     self = args[0]
     src = args[1]
+    if type(self) is torch.Tensor and isinstance(src, AffineQuantizedTensor):
+        func(self, src.dequantize())
+        return
     if _same_metadata(self, src):
         self_tensors = self.__tensor_flatten__()[0]
         for tensor_name in self_tensors:
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e3096c79f8..a14f2a3956 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -485,7 +485,7 @@ def quantize_(
             or ("_default" in config.fqn_to_config and _is_linear(module))
         ):
             module_name = (
-                module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
+                module_fqn.rsplit(".", 1)[0] if "." in module_fqn else module_fqn
             )
             # this replaces inplace, so no need to reassign
             _fqn_to_config_handler(module, module_name, config)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 47395a15af..d6204f1931 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -202,13 +202,14 @@ def from_hp(
         else:
             maybe_hp_value_ub_tensor = None
         if isinstance(granularity, PerRow):
-            data, scale = torch.ops.triton.quantize_fp8_row(
-                hp_tensor, scale_ub=maybe_hp_value_ub_tensor
-            )
-            scale_shape = []
-            for i in range(hp_tensor.ndim):
-                scale_shape.append(hp_tensor.shape[i] // block_size[i])
-            scale = scale.reshape(*scale_shape)
+            with torch.cuda.device(hp_tensor.device):
+                data, scale = torch.ops.triton.quantize_fp8_row(
+                    hp_tensor, scale_ub=maybe_hp_value_ub_tensor
+                )
+                scale_shape = []
+                for i in range(hp_tensor.ndim):
+                    scale_shape.append(hp_tensor.shape[i] // block_size[i])
+                scale = scale.reshape(*scale_shape)
         else:
             assert isinstance(granularity, PerTensor), (
                 f"Expected per tensor, got {granularity}"
diff --git a/torchao/utils.py b/torchao/utils.py
index 5af3e00cfa..187c6c41eb 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -571,6 +571,9 @@ def _same_metadata(self: TorchAOBaseTensor, src: TorchAOBaseTensor) -> bool:
 def _(func, types, args, kwargs):
     self = args[0]
     src = args[1]
+    if type(self) is torch.Tensor and isinstance(src, TorchAOBaseTensor):
+        func(self, src.dequantize())
+        return
     if _same_metadata(self, src):
         self_tensors = self.__tensor_flatten__()[0]
         for tensor_name in self_tensors:

From 31e738135207eb7d8d7a2de69a6d3e815053fbd2 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Mon, 3 Nov 2025 11:14:13 -0800
Subject: [PATCH 5/5] update

---
 torchao/core/config.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/torchao/core/config.py b/torchao/core/config.py
index 330e6a42af..421dee52b8 100644
--- a/torchao/core/config.py
+++ b/torchao/core/config.py
@@ -144,10 +144,11 @@ def default(self, o):
             return [self.encode_value(item) for item in o]
 
         elif isinstance(o, tuple):
-            raise NotImplementedError(
-                "Tuples will be serialized as List in JSON, so we recommend to use "
-                f"Lists instead to avoid surprises. got: {o}"
-            )
+            return [self.encode_value(item) for item in o]
+            # raise NotImplementedError(
+            #     "Tuples will be serialized as List in JSON, so we recommend to use "
+            #     f"Lists instead to avoid surprises. got: {o}"
+            # )
 
         if isinstance(o, dict):
             return {k: self.encode_value(v) for k, v in o.items()}
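
For the `copy_` changes in PATCH 4, a small sketch of the direction they
enable: copying a quantized tensor into a plain `torch.Tensor` now dequantizes
instead of failing the metadata check (shapes and config are illustrative):

    import torch
    from torchao.quantization import quantize_, Float8WeightOnlyConfig

    linear = torch.nn.Linear(64, 64, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, Float8WeightOnlyConfig())

    plain = torch.empty(64, 64, dtype=torch.bfloat16, device="cuda")
    # self is a plain tensor and src is a torchao tensor subclass, so the new
    # branch re-dispatches as func(self, src.dequantize()) and returns.
    plain.copy_(linear.weight)
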
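And for PATCH 5, a sketch of the round-trip the encoder change permits,
assuming the `config_to_dict`/`config_from_dict` helpers in
`torchao.core.config` wrap the encoder shown above; note that tuple-valued
fields deserialize as lists:

    from torchao.core.config import config_from_dict, config_to_dict
    from torchao.quantization import (
        Float8DynamicActivationFloat8WeightConfig,
        PerRow,
    )

    # A granularity passed as a tuple previously hit the NotImplementedError
    # branch during serialization; it is now encoded as a JSON list.
    cfg = Float8DynamicActivationFloat8WeightConfig(granularity=(PerRow(), PerRow()))
    payload = config_to_dict(cfg)
    restored = config_from_dict(payload)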