Commit 1169efa

Remove config functions like int4_weight_only

**Summary:** As a follow-up to #2994, this commit removes all quantization functions that were used as configs. These functions were deprecated in 0.14.0 and will be removed in the next release, 0.15.0.

**Test Plan:** CI

ghstack-source-id: 31fc27d
Pull Request resolved: #3145

1 parent: bb65dbc
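For anyone updating call sites, the migration is mechanical: each deprecated config function maps to a config class with the same name in CamelCase plus a `Config` suffix. A minimal before/after sketch, using the same toy model as the `quantize_` docstring below (the `group_size` value is just an example):

```python
import torch.nn as nn

from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# Before (deprecated in 0.14.0, removed by this commit):
# quantize_(model, int4_weight_only(group_size=32))

# After: construct the config class directly
quantize_(model, Int4WeightOnlyConfig(group_size=32))
```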

5 files changed: +3 −170 lines

README.md (1 addition, 1 deletion)

```diff
@@ -258,7 +258,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
 
```

test/quantization/test_quant_api.py (0 additions, 51 deletions)

```diff
@@ -10,7 +10,6 @@
 import gc
 import tempfile
 import unittest
-import warnings
 from pathlib import Path
 
 import torch
@@ -847,56 +846,6 @@ def test_int4wo_cuda_serialization(self):
         # load state_dict in cuda
         model.load_state_dict(sd, assign=True)
 
-    def test_config_deprecation(self):
-        """
-        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
-        """
-        from torchao.quantization import (
-            float8_dynamic_activation_float8_weight,
-            float8_static_activation_float8_weight,
-            float8_weight_only,
-            fpx_weight_only,
-            gemlite_uintx_weight_only,
-            int4_dynamic_activation_int4_weight,
-            int4_weight_only,
-            int8_dynamic_activation_int4_weight,
-            int8_dynamic_activation_int8_weight,
-            int8_weight_only,
-            uintx_weight_only,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            float8_dynamic_activation_float8_weight: (),
-            float8_static_activation_float8_weight: (torch.randn(3),),
-            float8_weight_only: (),
-            fpx_weight_only: (3, 2),
-            gemlite_uintx_weight_only: (),
-            int4_dynamic_activation_int4_weight: (),
-            int4_weight_only: (),
-            int8_dynamic_activation_int4_weight: (),
-            int8_dynamic_activation_int8_weight: (),
-            int8_weight_only: (),
-            uintx_weight_only: (torch.uint4,),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-            # Each call should trigger the warning only once
-            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-            for w in _warnings:
-                self.assertIn(
-                    "is deprecated and will be removed in a future release",
-                    str(w.message),
-                )
-
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
 
```
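The deleted test leaned on Python's warning registry: under the default filter action, `warnings.warn` reports a given message only once per call site, which is why calling each deprecated API twice was expected to record exactly one warning apiece, and why `warnings.resetwarnings()` was needed to clear state left over from earlier tests. A standalone sketch of that standard-library behavior (the function name is illustrative, not from the codebase):

```python
import warnings


def deprecated_fn():
    # Under the "default" filter action, identical warnings raised from the
    # same source location are reported only once.
    warnings.warn("`deprecated_fn` is deprecated and will be removed in a future release.")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("default")
    deprecated_fn()
    deprecated_fn()

assert len(caught) == 1  # the second identical warning is deduplicated
```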

torchao/quantization/__init__.py (0 additions, 24 deletions)

```diff
@@ -64,21 +64,9 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
-    float8_dynamic_activation_float8_weight,
-    float8_static_activation_float8_weight,
-    float8_weight_only,
-    fpx_weight_only,
-    gemlite_uintx_weight_only,
-    int4_dynamic_activation_int4_weight,
-    int4_weight_only,
-    int8_dynamic_activation_int4_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
-    int8_dynamic_activation_int8_weight,
-    int8_weight_only,
     intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
-    uintx_weight_only,
 )
 from .quant_primitives import (
     MappingType,
@@ -131,19 +119,7 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
-    "int4_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
-    "uintx_weight_only",
-    "fpx_weight_only",
-    "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",
```

torchao/quantization/quant_api.py (1 addition, 74 deletions)

```diff
@@ -96,7 +96,6 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
-    _ConfigDeprecationWrapper,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -148,18 +147,7 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize_",
-    "int8_dynamic_activation_int4_weight",
-    "int8_dynamic_activation_int8_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight",
-    "int4_weight_only",
-    "int8_weight_only",
     "intx_quantization_aware_training",
-    "float8_weight_only",
-    "uintx_weight_only",
-    "fpx_weight_only",
-    "gemlite_uintx_weight_only",
-    "float8_dynamic_activation_float8_weight",
-    "float8_static_activation_float8_weight",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -519,7 +507,7 @@ def quantize_(
         # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
         # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
         # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
-        from torchao.quantization.quant_api import int4_weight_only
+        from torchao.quantization.quant_api import Int4WeightOnlyConfig
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -641,12 +629,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module,
@@ -1012,12 +994,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
-    "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
-)
-
-
 @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
 def _int4_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1075,12 +1051,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
-    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
 def _gemlite_uintx_weight_only_transform(
     module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1158,11 +1128,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
 
 
-# for BC
-# TODO maybe change other callsites
-int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
-
-
 def _int4_weight_only_quantize_tensor(weight, config):
     # TODO(future PR): perhaps move this logic to a different file, to keep the API
     # file clean of implementation details
@@ -1374,10 +1339,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
 
 
-# for BC
-int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
-
-
 def _int8_weight_only_quantize_tensor(weight, config):
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
@@ -1535,12 +1496,6 @@ def __post_init__(self):
         )
 
 
-# for BC
-int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
-    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
-)
-
-
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     layout = config.layout
     act_mapping_type = config.act_mapping_type
@@ -1646,12 +1601,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
 
 
-# for BC
-float8_weight_only = _ConfigDeprecationWrapper(
-    "float8_weight_only", Float8WeightOnlyConfig
-)
-
-
 def _float8_weight_only_quant_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
@@ -1806,12 +1755,6 @@ def __post_init__(self):
         self.granularity = [activation_granularity, weight_granularity]
 
 
-# for bc
-float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
-)
-
-
 def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -1981,12 +1924,6 @@ def __post_init__(self):
         )
 
 
-# for bc
-float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
-    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
-)
-
-
 @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2066,12 +2003,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
 
 
-# for BC
-uintx_weight_only = _ConfigDeprecationWrapper(
-    "uintx_weight_only", UIntXWeightOnlyConfig
-)
-
-
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
 def _uintx_weight_only_transform(
     module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2350,10 +2281,6 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
 
 
-# for BC
-fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
-
-
 @register_quantize_module_handler(FPXWeightOnlyConfig)
 def _fpx_weight_only_transform(
     module: torch.nn.Module, config: FPXWeightOnlyConfig
```

torchao/utils.py (1 addition, 20 deletions)

```diff
@@ -12,7 +12,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional, Type
+from typing import Any, Callable, Optional
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -433,25 +433,6 @@ def __eq__(self, other):
 TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
 
 
-class _ConfigDeprecationWrapper:
-    """
-    A deprecation wrapper that directs users from a deprecated "config function"
-    (e.g. `int4_weight_only`) to the replacement config class.
-    """
-
-    def __init__(self, deprecated_name: str, config_cls: Type):
-        self.deprecated_name = deprecated_name
-        self.config_cls = config_cls
-
-    def __call__(self, *args, **kwargs):
-        warnings.warn(
-            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
-            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
-            f" quantize_(model, {self.config_cls.__name__}(...))"
-        )
-        return self.config_cls(*args, **kwargs)
-
-
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
```
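For reference, this is how the removed wrapper tied the pieces together, reconstructed from the hunks above (behavior as of 0.14.x; `group_size=32` is an arbitrary example value):

```python
# Each deprecated name was bound to a wrapper instance in quant_api.py...
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)

# ...so calling the old "config function" warned and returned the real config:
config = int4_weight_only(group_size=32)
# UserWarning: `int4_weight_only` is deprecated and will be removed in a
# future release. Please use `Int4WeightOnlyConfig` instead. ...
assert isinstance(config, Int4WeightOnlyConfig)
```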
