Commit afcac24
Remove config functions like int4_weight_only (#3145)
**Summary:** As a follow-up to #2994, this commit removes all of the quantization functions that doubled as configs (e.g. `int4_weight_only`). These functions were deprecated in 0.14.0 and are removed here, ahead of the next release, 0.15.0.

**Test Plan:** CI
1 parent 6259e98 commit afcac24

5 files changed: +21 −150 lines
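For anyone updating call sites, the migration is mechanical: replace each lowercase config function with its corresponding config class. A minimal sketch (the model and `group_size` are illustrative, taken from the `quantize_` docstring example updated below):

```python
import torch.nn as nn
from torchao.quantization import quantize_, Int4WeightOnlyConfig

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# Before (deprecated in 0.14.0, removed by this commit):
#   from torchao.quantization import int4_weight_only
#   quantize_(m, int4_weight_only(group_size=32))

# After: construct the config class directly
quantize_(m, Int4WeightOnlyConfig(group_size=32))
```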

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -243,7 +243,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
 
```
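In context, the updated README line corresponds to usage like the following. This is a sketch, not code from the commit; per the linked fp6 README these kernels target fp16 inference on CUDA, and the `(3, 2)` arguments select 3 exponent and 2 mantissa bits:

```python
import torch.nn as nn
from torchao.quantization import quantize_, FPXWeightOnlyConfig

# fp6 = 1 sign bit + 3 exponent bits + 2 mantissa bits;
# fp16 weights on GPU, per the fp6 README linked above
model = nn.Sequential(nn.Linear(64, 64)).half().cuda()
quantize_(model, FPXWeightOnlyConfig(3, 2))
```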

test/quantization/test_quant_api.py

Lines changed: 18 additions & 31 deletions
```diff
@@ -792,38 +792,28 @@ def test_int4wo_cuda_serialization(self):
 
     def test_config_deprecation(self):
         """
-        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
+        Test that old config functions like `Int8DynamicActivationInt4WeightConfig` trigger deprecation warnings.
         """
         from torchao.quantization import (
-            float8_dynamic_activation_float8_weight,
-            float8_static_activation_float8_weight,
-            float8_weight_only,
-            fpx_weight_only,
-            gemlite_uintx_weight_only,
-            int4_dynamic_activation_int4_weight,
-            int4_weight_only,
-            int8_dynamic_activation_int4_weight,
-            int8_dynamic_activation_int8_weight,
-            int8_weight_only,
-            uintx_weight_only,
+            Float8StaticActivationFloat8WeightConfig,
+            FPXWeightOnlyConfig,
+            GemliteUIntXWeightOnlyConfig,
+            Int4DynamicActivationInt4WeightConfig,
+            Int8DynamicActivationInt4WeightConfig,
+            UIntXWeightOnlyConfig,
         )
 
         # Reset deprecation warning state, otherwise we won't log warnings here
         warnings.resetwarnings()
 
         # Map from deprecated API to the args needed to instantiate it
         deprecated_apis_to_args = {
-            float8_dynamic_activation_float8_weight: (),
-            float8_static_activation_float8_weight: (torch.randn(3)),
-            float8_weight_only: (),
-            fpx_weight_only: (3, 2),
-            gemlite_uintx_weight_only: (),
-            int4_dynamic_activation_int4_weight: (),
-            int4_weight_only: (),
-            int8_dynamic_activation_int4_weight: (),
-            int8_dynamic_activation_int8_weight: (),
-            int8_weight_only: (),
-            uintx_weight_only: (torch.uint4,),
+            Float8StaticActivationFloat8WeightConfig: (torch.randn(3),),
+            FPXWeightOnlyConfig: (3, 2),
+            GemliteUIntXWeightOnlyConfig: (),
+            Int4DynamicActivationInt4WeightConfig: (),
+            Int8DynamicActivationInt4WeightConfig: (),
+            UIntXWeightOnlyConfig: (torch.uint4,),
         }
 
         # Call each deprecated API twice
@@ -832,19 +822,16 @@ def test_config_deprecation(self):
                 cls(*args)
                 cls(*args)
 
-                # Each call should have at least one warning.
-                # Some of them can have two warnings - one for deprecation,
-                # one for moving to prototype
-                # 1 warning - just deprecation
-                # 2 warnings - deprecation and prototype warnings
-                self.assertTrue(len(_warnings) in (1, 2))
+                self.assertTrue(len(_warnings) == 1)
                 found_deprecated = False
                 for w in _warnings:
-                    if "is deprecated and will be removed in a future release" in str(
+                    if "will be moving to prototype in a future release" in str(
                         w.message
                     ):
                         found_deprecated = True
-                self.assertTrue(found_deprecated)
+                self.assertTrue(
+                    found_deprecated, f"did not find deprecated warning for {cls}"
+                )
 
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
```
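With the BC shims gone, each deprecated-config instantiation now emits exactly one warning (the "moving to prototype" one), so the test can tighten its assertion to `len(_warnings) == 1`. That assertion leans on Python's default warning filter, which reports a given warning once per call site; this is also why the test calls `warnings.resetwarnings()` first. A standalone sketch of the same mechanics (illustrative, not torchao code):

```python
import warnings

def make_config():
    # Stand-in for a config class whose __post_init__ warns
    warnings.warn("`X` will be moving to prototype in a future release")

warnings.resetwarnings()  # restore default filters; forget "already warned" state
with warnings.catch_warnings(record=True) as caught:
    make_config()
    make_config()  # same call site: deduplicated by the default filter

assert len(caught) == 1
assert "will be moving to prototype in a future release" in str(caught[0].message)
```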

torchao/quantization/__init__.py

Lines changed: 0 additions & 24 deletions
```diff
@@ -65,22 +65,10 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_static_activation_float8_weight,
-    float8_weight_only,
-    fpx_weight_only,
     fqn_matches_fqn_config,
-    gemlite_uintx_weight_only,
-    int4_dynamic_activation_int4_weight,
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
     intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
-    uintx_weight_only,
 )
 from .quant_primitives import (
     MappingType,
@@ -131,20 +119,8 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "int4_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
-    "uintx_weight_only",
-    "fpx_weight_only",
     "fqn_matches_fqn_config",
-    "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
```

torchao/quantization/quant_api.py

Lines changed: 1 addition & 74 deletions
```diff
@@ -97,7 +97,6 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
-    _ConfigDeprecationWrapper,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -146,18 +145,7 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize_",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "uintx_weight_only",
-    "fpx_weight_only",
-    "gemlite_uintx_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -464,7 +452,7 @@ def quantize_(
         # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
         # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
         # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
-        from torchao.quantization.quant_api import int4_weight_only
+        from torchao.quantization.quant_api import Int4WeightOnlyConfig
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -599,12 +587,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module,
@@ -973,12 +955,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
 def _int4_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1039,12 +1015,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
-    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
 def _gemlite_uintx_weight_only_transform(
     module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1122,11 +1092,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
 
 
-# for BC
-# TODO maybe change other callsites
-int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
-
-
 def _int4_weight_only_quantize_tensor(weight, config):
     # TODO(future PR): perhaps move this logic to a different file, to keep the API
     # file clean of implementation details
@@ -1338,10 +1303,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
 
 
-# for BC
-int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
-
-
 def _int8_weight_only_quantize_tensor(weight, config):
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
@@ -1506,12 +1467,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
-)
-
-
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     layout = config.layout
     act_mapping_type = config.act_mapping_type
@@ -1617,12 +1572,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
 
 
-# for BC
-float8_weight_only = _ConfigDeprecationWrapper(
-    "float8_weight_only", Float8WeightOnlyConfig
-)
-
-
 def _float8_weight_only_quant_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
@@ -1800,12 +1749,6 @@ def __post_init__(self):
             self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum)
 
 
-# for bc
-float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
-)
-
-
 def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -1995,12 +1938,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
-)
-
-
 @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2086,12 +2023,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-uintx_weight_only = _ConfigDeprecationWrapper(
-    "uintx_weight_only", UIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
 def _uintx_weight_only_transform(
     module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2373,10 +2304,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
-
-
 @register_quantize_module_handler(FPXWeightOnlyConfig)
 def _fpx_weight_only_transform(
     module: torch.nn.Module, config: FPXWeightOnlyConfig
```
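Note that only the function-style aliases are deleted; every `@register_quantize_module_handler(...)` registration survives, so `quantize_` still dispatches on the config class. A simplified sketch of that decorator pattern (an assumption for illustration — torchao's actual registry differs in detail):

```python
from typing import Callable, Dict, Type

# Hypothetical registry mapping config classes to module-transform functions
_HANDLERS: Dict[Type, Callable] = {}

def register_quantize_module_handler(config_cls: Type) -> Callable:
    """Associate a config class with the function that transforms a module for it."""
    def decorator(fn: Callable) -> Callable:
        _HANDLERS[config_cls] = fn
        return fn
    return decorator

# Usage mirrors the hunks above, e.g.:
#   @register_quantize_module_handler(FPXWeightOnlyConfig)
#   def _fpx_weight_only_transform(module, config): ...
# quantize_ can then look up _HANDLERS[type(config)] and apply the transform,
# which is why quantize_(model, FPXWeightOnlyConfig(3, 2)) keeps working.
```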

torchao/utils.py

Lines changed: 1 addition & 20 deletions
```diff
@@ -12,7 +12,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -434,25 +434,6 @@ def __eq__(self, other):
 TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
 
 
-class _ConfigDeprecationWrapper:
-    """
-    A deprecation wrapper that directs users from a deprecated "config function"
-    (e.g. `int4_weight_only`) to the replacement config class.
-    """
-
-    def __init__(self, deprecated_name: str, config_cls: Type):
-        self.deprecated_name = deprecated_name
-        self.config_cls = config_cls
-
-    def __call__(self, *args, **kwargs):
-        warnings.warn(
-            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
-            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
-            f" quantize_(model, {self.config_cls.__name__}(...))"
-        )
-        return self.config_cls(*args, **kwargs)
-
-
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
```
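For reference, the deleted shim was a plain callable: invoking the old name emitted the deprecation message and returned an instance of the replacement config class. Roughly, pre-removal behavior looked like this (a sketch reusing the `_ConfigDeprecationWrapper` definition shown removed above):

```python
# Pre-removal behavior, for illustration only — this no longer works after
# this commit, since _ConfigDeprecationWrapper and the alias are both gone.
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)

cfg = int4_weight_only(group_size=32)  # warns: "`int4_weight_only` is deprecated..."
assert isinstance(cfg, Int4WeightOnlyConfig)  # forwarded to the new config class
```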
