Commit 69d38c0
Remove config functions like int4_weight_only
**Summary:** As a follow-up to #2994, this commit removes all quantization functions that were used as configs, such as `int4_weight_only`. These functions were deprecated in 0.14.0, with removal scheduled for the next release, 0.15.0.

**Test Plan:** CI

ghstack-source-id: 31fc27d
Pull Request resolved: #3145
1 parent 6c24a7a
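For callers migrating off the removed functions, the change is mechanical: each config-function call becomes construction of the corresponding `*Config` class. A minimal before/after sketch; the `group_size=32, version=1` arguments simply mirror the `quantize_` docstring updated in this commit:

```python
import torch.nn as nn

from torchao.quantization import Int4WeightOnlyConfig, quantize_

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# Before this commit (deprecated since 0.14.0, now removed):
#   quantize_(m, int4_weight_only(group_size=32))
# After:
quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
```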

File tree

5 files changed (+3, -170 lines)

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -258,7 +258,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
 
```
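For context on the updated README line: assuming, as the removed `fpx_weight_only(3, 2)` test entry suggests, that the two positional arguments are the exponent and mantissa bit counts (fp6 = 1 sign + 3 exponent + 2 mantissa bits), the full call site would look like this sketch:

```python
import torch.nn as nn

from torchao.quantization import FPXWeightOnlyConfig, quantize_

# (3, 2) = exponent bits, mantissa bits -> fp6 weight-only quantization.
model = nn.Sequential(nn.Linear(64, 64))
quantize_(model, FPXWeightOnlyConfig(3, 2))
```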

test/quantization/test_quant_api.py

Lines changed: 0 additions & 51 deletions

```diff
@@ -10,7 +10,6 @@
 import gc
 import tempfile
 import unittest
-import warnings
 from pathlib import Path
 
 import torch
@@ -847,56 +846,6 @@ def test_int4wo_cuda_serialization(self):
         # load state_dict in cuda
         model.load_state_dict(sd, assign=True)
 
-    def test_config_deprecation(self):
-        """
-        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
-        """
-        from torchao.quantization import (
-            float8_dynamic_activation_float8_weight,
-            float8_static_activation_float8_weight,
-            float8_weight_only,
-            fpx_weight_only,
-            gemlite_uintx_weight_only,
-            int4_dynamic_activation_int4_weight,
-            int4_weight_only,
-            int8_dynamic_activation_int4_weight,
-            int8_dynamic_activation_int8_weight,
-            int8_weight_only,
-            uintx_weight_only,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            float8_dynamic_activation_float8_weight: (),
-            float8_static_activation_float8_weight: (torch.randn(3),),
-            float8_weight_only: (),
-            fpx_weight_only: (3, 2),
-            gemlite_uintx_weight_only: (),
-            int4_dynamic_activation_int4_weight: (),
-            int4_weight_only: (),
-            int8_dynamic_activation_int4_weight: (),
-            int8_dynamic_activation_int8_weight: (),
-            int8_weight_only: (),
-            uintx_weight_only: (torch.uint4,),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-            # Each call should trigger the warning only once
-            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-            for w in _warnings:
-                self.assertIn(
-                    "is deprecated and will be removed in a future release",
-                    str(w.message),
-                )
-
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
```
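The removed test leaned on a subtlety of Python's warnings machinery: with the default filter action, a given warning is recorded only once per call site (tracked via `__warningregistry__`), which is why each deprecated API was called twice but exactly one warning per API was expected; `warnings.resetwarnings()` was needed so filters mutated by earlier tests did not suppress the warnings entirely. A standalone sketch of the once-per-call-site behavior (plain Python, not torchao code):

```python
import warnings


def deprecated_call():
    # Same call site on both invocations, so the "default" filter
    # deduplicates everything after the first warning.
    warnings.warn("deprecated and will be removed in a future release")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("default")
    deprecated_call()
    deprecated_call()

assert len(caught) == 1  # second call suppressed by the warning registry
```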

torchao/quantization/__init__.py

Lines changed: 0 additions & 24 deletions

```diff
@@ -64,21 +64,9 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_static_activation_float8_weight,
-    float8_weight_only,
-    fpx_weight_only,
-    gemlite_uintx_weight_only,
-    int4_dynamic_activation_int4_weight,
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
     intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
-    uintx_weight_only,
 )
 from .quant_primitives import (
     MappingType,
@@ -131,19 +119,7 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "int4_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
-    "uintx_weight_only",
-    "fpx_weight_only",
-    "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
```

torchao/quantization/quant_api.py

Lines changed: 1 addition & 74 deletions

```diff
@@ -96,7 +96,6 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
-    _ConfigDeprecationWrapper,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -148,18 +147,7 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize_",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "uintx_weight_only",
-    "fpx_weight_only",
-    "gemlite_uintx_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -507,7 +495,7 @@ def quantize_(
         # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
         # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
         # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
-        from torchao.quantization.quant_api import int4_weight_only
+        from torchao.quantization.quant_api import Int4WeightOnlyConfig
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -629,12 +617,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module,
@@ -1000,12 +982,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
 def _int4_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1063,12 +1039,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
-    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
 def _gemlite_uintx_weight_only_transform(
     module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1146,11 +1116,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
 
 
-# for BC
-# TODO maybe change other callsites
-int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
-
-
 def _int4_weight_only_quantize_tensor(weight, config):
     # TODO(future PR): perhaps move this logic to a different file, to keep the API
     # file clean of implementation details
@@ -1362,10 +1327,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
 
 
-# for BC
-int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
-
-
 def _int8_weight_only_quantize_tensor(weight, config):
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
@@ -1523,12 +1484,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
-)
-
-
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     layout = config.layout
     act_mapping_type = config.act_mapping_type
@@ -1634,12 +1589,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
 
 
-# for BC
-float8_weight_only = _ConfigDeprecationWrapper(
-    "float8_weight_only", Float8WeightOnlyConfig
-)
-
-
 def _float8_weight_only_quant_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
@@ -1798,12 +1747,6 @@ def __post_init__(self):
         self.granularity = [activation_granularity, weight_granularity]
 
 
-# for bc
-float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
-)
-
-
 def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -1979,12 +1922,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
-)
-
-
 @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2067,12 +2004,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
 
 
-# for BC
-uintx_weight_only = _ConfigDeprecationWrapper(
-    "uintx_weight_only", UIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
 def _uintx_weight_only_transform(
     module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2351,10 +2282,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
 
 
-# for BC
-fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
-
-
 @register_quantize_module_handler(FPXWeightOnlyConfig)
 def _fpx_weight_only_transform(
     module: torch.nn.Module, config: FPXWeightOnlyConfig
```
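Taken together, the removed "for BC" aliases map one-to-one onto config classes that remain in place. A reference sketch of the replacements, using the pairings visible in this diff (all of these classes are defined in `torchao/quantization/quant_api.py`):

```python
# One-to-one replacements for the removed config functions:
from torchao.quantization.quant_api import (
    FPXWeightOnlyConfig,                        # was fpx_weight_only
    Float8DynamicActivationFloat8WeightConfig,  # was float8_dynamic_activation_float8_weight
    Float8StaticActivationFloat8WeightConfig,   # was float8_static_activation_float8_weight
    Float8WeightOnlyConfig,                     # was float8_weight_only
    GemliteUIntXWeightOnlyConfig,               # was gemlite_uintx_weight_only
    Int4DynamicActivationInt4WeightConfig,      # was int4_dynamic_activation_int4_weight
    Int4WeightOnlyConfig,                       # was int4_weight_only
    Int8DynamicActivationInt4WeightConfig,      # was int8_dynamic_activation_int4_weight
    Int8DynamicActivationInt8WeightConfig,      # was int8_dynamic_activation_int8_weight
    Int8WeightOnlyConfig,                       # was int8_weight_only
    UIntXWeightOnlyConfig,                      # was uintx_weight_only
)
```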

torchao/utils.py

Lines changed: 1 addition & 20 deletions

```diff
@@ -12,7 +12,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -433,25 +433,6 @@ def __eq__(self, other):
 TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
 
 
-class _ConfigDeprecationWrapper:
-    """
-    A deprecation wrapper that directs users from a deprecated "config function"
-    (e.g. `int4_weight_only`) to the replacement config class.
-    """
-
-    def __init__(self, deprecated_name: str, config_cls: Type):
-        self.deprecated_name = deprecated_name
-        self.config_cls = config_cls
-
-    def __call__(self, *args, **kwargs):
-        warnings.warn(
-            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
-            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
-            f"    quantize_(model, {self.config_cls.__name__}(...))"
-        )
-        return self.config_cls(*args, **kwargs)
-
-
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
```
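The removed wrapper is a small, reusable pattern: a callable that emits a deprecation warning and then forwards construction to the replacement class. A self-contained sketch of the same idea, using a hypothetical `NewConfig` class rather than torchao classes:

```python
import warnings
from dataclasses import dataclass
from typing import Type


@dataclass
class NewConfig:
    group_size: int = 128


class ConfigDeprecationWrapper:
    """Callable stand-in that redirects a deprecated config function to a class."""

    def __init__(self, deprecated_name: str, config_cls: Type):
        self.deprecated_name = deprecated_name
        self.config_cls = config_cls

    def __call__(self, *args, **kwargs):
        # Warn, then construct and return the replacement config object.
        warnings.warn(
            f"`{self.deprecated_name}` is deprecated; use "
            f"`{self.config_cls.__name__}` instead."
        )
        return self.config_cls(*args, **kwargs)


# Old call sites keep working, but warn and return the new config object:
old_config_fn = ConfigDeprecationWrapper("old_config_fn", NewConfig)
assert isinstance(old_config_fn(group_size=32), NewConfig)
```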
