From afcac24a39d1af4fdd94ac8ff1db6559e6cc902c Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 13 Oct 2025 13:42:57 -0400 Subject: [PATCH] Remove config functions like `int4_weight_only` (#3145) **Summary:** As a follow-up to https://github.com/pytorch/ao/pull/2994, this commit removes all quantization functions that were used as configs. These functions were deprecated in 0.14.0 and will be removed in the next release, 0.15.0. **Test Plan:** CI --- README.md | 2 +- test/quantization/test_quant_api.py | 49 +++++++------------ torchao/quantization/__init__.py | 24 --------- torchao/quantization/quant_api.py | 75 +---------------------------- torchao/utils.py | 21 +------- 5 files changed, 21 insertions(+), 150 deletions(-) diff --git a/README.md b/README.md index 8f634400e1..527b2fdd3d 100644 --- a/README.md +++ b/README.md @@ -243,7 +243,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow -1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))` +1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))` 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index bd46a543fe..ab91268058 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -792,20 +792,15 @@ def test_int4wo_cuda_serialization(self): def test_config_deprecation(self): """ - Test that old config functions like `int4_weight_only` trigger deprecation warnings. + Test that old config functions like `Int8DynamicActivationInt4WeightConfig` trigger deprecation warnings. """ from torchao.quantization import ( - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, - uintx_weight_only, + Float8StaticActivationFloat8WeightConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt4WeightConfig, + UIntXWeightOnlyConfig, ) # Reset deprecation warning state, otherwise we won't log warnings here @@ -813,17 +808,12 @@ def test_config_deprecation(self): # Map from deprecated API to the args needed to instantiate it deprecated_apis_to_args = { - float8_dynamic_activation_float8_weight: (), - float8_static_activation_float8_weight: (torch.randn(3)), - float8_weight_only: (), - fpx_weight_only: (3, 2), - gemlite_uintx_weight_only: (), - int4_dynamic_activation_int4_weight: (), - int4_weight_only: (), - int8_dynamic_activation_int4_weight: (), - int8_dynamic_activation_int8_weight: (), - int8_weight_only: (), - uintx_weight_only: (torch.uint4,), + Float8StaticActivationFloat8WeightConfig: (torch.randn(3),), + FPXWeightOnlyConfig: (3, 2), + GemliteUIntXWeightOnlyConfig: (), + Int4DynamicActivationInt4WeightConfig: (), + Int8DynamicActivationInt4WeightConfig: (), + UIntXWeightOnlyConfig: (torch.uint4,), } # Call each deprecated API twice @@ -832,19 +822,16 @@ def test_config_deprecation(self): cls(*args) cls(*args) - # Each call should have at least one warning. - # Some of them can have two warnings - one for deprecation, - # one for moving to prototype - # 1 warning - just deprecation - # 2 warnings - deprecation and prototype warnings - self.assertTrue(len(_warnings) in (1, 2)) + self.assertTrue(len(_warnings) == 1) found_deprecated = False for w in _warnings: - if "is deprecated and will be removed in a future release" in str( + if "will be moving to prototype in a future release" in str( w.message ): found_deprecated = True - self.assertTrue(found_deprecated) + self.assertTrue( + found_deprecated, f"did not find deprecated warning for {cls}" + ) common_utils.instantiate_parametrized_tests(TestQuantFlow) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 8dd6410597..58c2e347e0 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -65,22 +65,10 @@ PlainLayout, TensorCoreTiledLayout, UIntXWeightOnlyConfig, - float8_dynamic_activation_float8_weight, - float8_static_activation_float8_weight, - float8_weight_only, - fpx_weight_only, fqn_matches_fqn_config, - gemlite_uintx_weight_only, - int4_dynamic_activation_int4_weight, - int4_weight_only, - int8_dynamic_activation_int4_weight, - int8_dynamic_activation_int8_semi_sparse_weight, - int8_dynamic_activation_int8_weight, - int8_weight_only, intx_quantization_aware_training, quantize_, swap_conv2d_1x1_to_linear, - uintx_weight_only, ) from .quant_primitives import ( MappingType, @@ -131,20 +119,8 @@ "ALL_AUTOQUANT_CLASS_LIST", # top level API - manual "quantize_", - "int4_dynamic_activation_int4_weight", - "int8_dynamic_activation_int4_weight", - "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", - "int4_weight_only", - "int8_weight_only", "intx_quantization_aware_training", - "float8_weight_only", - "float8_dynamic_activation_float8_weight", - "float8_static_activation_float8_weight", - "uintx_weight_only", - "fpx_weight_only", "fqn_matches_fqn_config", - "gemlite_uintx_weight_only", "swap_conv2d_1x1_to_linear", "Int4DynamicActivationInt4WeightConfig", "Int8DynamicActivationInt4WeightConfig", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7054e00564..6321f1063c 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -97,7 +97,6 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -146,18 +145,7 @@ "autoquant", "_get_subclass_inserter", "quantize_", - "int8_dynamic_activation_int4_weight", - "int8_dynamic_activation_int8_weight", - "int8_dynamic_activation_int8_semi_sparse_weight", - "int4_weight_only", - "int8_weight_only", "intx_quantization_aware_training", - "float8_weight_only", - "uintx_weight_only", - "fpx_weight_only", - "gemlite_uintx_weight_only", - "float8_dynamic_activation_float8_weight", - "float8_static_activation_float8_weight", "Int8DynActInt4WeightQuantizer", "Float8DynamicActivationFloat8SemiSparseWeightConfig", "ModuleFqnToConfig", @@ -464,7 +452,7 @@ def quantize_( # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile) # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile - from torchao.quantization.quant_api import int4_weight_only + from torchao.quantization.quant_api import Int4WeightOnlyConfig m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32)) quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1)) @@ -599,12 +587,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig) def _int8_dynamic_activation_int4_weight_transform( module: torch.nn.Module, @@ -973,12 +955,6 @@ def __post_init__(self): ) -# for bc -int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper( - "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig -) - - @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig) def _int4_dynamic_activation_int4_weight_transform( module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig @@ -1039,12 +1015,6 @@ def __post_init__(self): ) -# for BC -gemlite_uintx_weight_only = _ConfigDeprecationWrapper( - "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig -) - - @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig) def _gemlite_uintx_weight_only_transform( module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig @@ -1122,11 +1092,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") -# for BC -# TODO maybe change other callsites -int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig) - - def _int4_weight_only_quantize_tensor(weight, config): # TODO(future PR): perhaps move this logic to a different file, to keep the API # file clean of implementation details @@ -1338,10 +1303,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig") -# for BC -int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig) - - def _int8_weight_only_quantize_tensor(weight, config): mapping_type = MappingType.SYMMETRIC target_dtype = torch.int8 @@ -1506,12 +1467,6 @@ def __post_init__(self): ) -# for BC -int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper( - "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig -) - - def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config): layout = config.layout act_mapping_type = config.act_mapping_type @@ -1617,12 +1572,6 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig") -# for BC -float8_weight_only = _ConfigDeprecationWrapper( - "float8_weight_only", Float8WeightOnlyConfig -) - - def _float8_weight_only_quant_tensor(weight, config): if config.version == 1: warnings.warn( @@ -1800,12 +1749,6 @@ def __post_init__(self): self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum) -# for bc -float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper( - "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig -) - - def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype @@ -1995,12 +1938,6 @@ def __post_init__(self): ) -# for bc -float8_static_activation_float8_weight = _ConfigDeprecationWrapper( - "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig -) - - @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig) def _float8_static_activation_float8_weight_transform( module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig @@ -2086,12 +2023,6 @@ def __post_init__(self): ) -# for BC -uintx_weight_only = _ConfigDeprecationWrapper( - "uintx_weight_only", UIntXWeightOnlyConfig -) - - @register_quantize_module_handler(UIntXWeightOnlyConfig) def _uintx_weight_only_transform( module: torch.nn.Module, config: UIntXWeightOnlyConfig @@ -2373,10 +2304,6 @@ def __post_init__(self): ) -# for BC -fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig) - - @register_quantize_module_handler(FPXWeightOnlyConfig) def _fpx_weight_only_transform( module: torch.nn.Module, config: FPXWeightOnlyConfig diff --git a/torchao/utils.py b/torchao/utils.py index 26191e2482..8d227dbf3d 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -12,7 +12,7 @@ from functools import reduce from importlib.metadata import version from math import gcd -from typing import Any, Callable, Optional, Type +from typing import Any, Callable, Optional import torch import torch.nn.utils.parametrize as parametrize @@ -434,25 +434,6 @@ def __eq__(self, other): TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") -class _ConfigDeprecationWrapper: - """ - A deprecation wrapper that directs users from a deprecated "config function" - (e.g. `int4_weight_only`) to the replacement config class. - """ - - def __init__(self, deprecated_name: str, config_cls: Type): - self.deprecated_name = deprecated_name - self.config_cls = config_cls - - def __call__(self, *args, **kwargs): - warnings.warn( - f"`{self.deprecated_name}` is deprecated and will be removed in a future release. " - f"Please use `{self.config_cls.__name__}` instead. Example usage:\n" - f" quantize_(model, {self.config_cls.__name__}(...))" - ) - return self.config_cls(*args, **kwargs) - - """ Helper function for implementing aten op or torch function dispatch and dispatching to these implementations.