Commit 8e63472

Revert "Remove config functions like int4_weight_only (#3145)" (#3192)
This reverts commit 53a66f8.
1 parent b644211 commit 8e63472
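
What this revert restores, in practice: the function-style shims (e.g. `int4_weight_only`) again resolve and forward to their config classes. A minimal sketch, assuming a torchao build that includes this revert:

```python
from torchao.quantization import Int4WeightOnlyConfig, int4_weight_only

# Deprecated function spelling, restored by this commit: emits a
# deprecation warning, then returns the corresponding config instance.
config_old = int4_weight_only(group_size=32)

# Preferred spelling: construct the config class directly.
config_new = Int4WeightOnlyConfig(group_size=32)

assert isinstance(config_old, Int4WeightOnlyConfig)
```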

File tree: 5 files changed, +170 -3 lines changed

README.md

Lines changed: 1 addition & 1 deletion

@@ -261,7 +261,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
 
 We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
 
-1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
+1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
 2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
 3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
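
The restored README snippet expands to the following sketch; the two-layer model is borrowed from the `quantize_` docstring touched later in this commit, and actually running fp6 assumes hardware and kernels supported by `torchao.dtypes.floatx`:

```python
import torch.nn as nn

from torchao.quantization import fpx_weight_only, quantize_

model = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# fp6 weight-only quantization (3 exponent bits, 2 mantissa bits), via the
# deprecated function spelling that this revert documents again.
quantize_(model, fpx_weight_only(3, 2))
```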

test/quantization/test_quant_api.py

Lines changed: 51 additions & 0 deletions

@@ -10,6 +10,7 @@
 import gc
 import tempfile
 import unittest
+import warnings
 from pathlib import Path
 
 import torch
@@ -786,6 +787,56 @@ def test_int4wo_cuda_serialization(self):
         # load state_dict in cuda
         model.load_state_dict(sd, assign=True)
 
+    def test_config_deprecation(self):
+        """
+        Test that old config functions like `int4_weight_only` trigger deprecation warnings.
+        """
+        from torchao.quantization import (
+            float8_dynamic_activation_float8_weight,
+            float8_static_activation_float8_weight,
+            float8_weight_only,
+            fpx_weight_only,
+            gemlite_uintx_weight_only,
+            int4_dynamic_activation_int4_weight,
+            int4_weight_only,
+            int8_dynamic_activation_int4_weight,
+            int8_dynamic_activation_int8_weight,
+            int8_weight_only,
+            uintx_weight_only,
+        )
+
+        # Reset deprecation warning state, otherwise we won't log warnings here
+        warnings.resetwarnings()
+
+        # Map from deprecated API to the args needed to instantiate it
+        deprecated_apis_to_args = {
+            float8_dynamic_activation_float8_weight: (),
+            float8_static_activation_float8_weight: (torch.randn(3),),
+            float8_weight_only: (),
+            fpx_weight_only: (3, 2),
+            gemlite_uintx_weight_only: (),
+            int4_dynamic_activation_int4_weight: (),
+            int4_weight_only: (),
+            int8_dynamic_activation_int4_weight: (),
+            int8_dynamic_activation_int8_weight: (),
+            int8_weight_only: (),
+            uintx_weight_only: (torch.uint4,),
+        }
+
+        with warnings.catch_warnings(record=True) as _warnings:
+            # Call each deprecated API twice
+            for cls, args in deprecated_apis_to_args.items():
+                cls(*args)
+                cls(*args)
+
+            # Each call should trigger the warning only once
+            self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
+            for w in _warnings:
+                self.assertIn(
+                    "is deprecated and will be removed in a future release",
+                    str(w.message),
+                )
+
 
 common_utils.instantiate_parametrized_tests(TestQuantFlow)
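
The warn-once assertion works because Python's default warning filter deduplicates warnings raised from the same source location, and `warnings.resetwarnings()` clears that previously-shown state. A standalone sketch of the same mechanism, independent of torchao (assumes default `-W` settings, i.e. no custom warning filters):

```python
import warnings


def deprecated_api():
    # Both calls below warn from this same line, so the default
    # "once per location" filter suppresses the second occurrence.
    warnings.warn("`deprecated_api` is deprecated and will be removed in a future release.")


warnings.resetwarnings()  # forget previously shown warnings
with warnings.catch_warnings(record=True) as caught:
    deprecated_api()
    deprecated_api()

assert len(caught) == 1
```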

torchao/quantization/__init__.py

Lines changed: 24 additions & 0 deletions

@@ -64,9 +64,21 @@
     PlainLayout,
     TensorCoreTiledLayout,
     UIntXWeightOnlyConfig,
+    float8_dynamic_activation_float8_weight,
+    float8_static_activation_float8_weight,
+    float8_weight_only,
+    fpx_weight_only,
+    gemlite_uintx_weight_only,
+    int4_dynamic_activation_int4_weight,
+    int4_weight_only,
+    int8_dynamic_activation_int4_weight,
+    int8_dynamic_activation_int8_semi_sparse_weight,
+    int8_dynamic_activation_int8_weight,
+    int8_weight_only,
     intx_quantization_aware_training,
     quantize_,
     swap_conv2d_1x1_to_linear,
+    uintx_weight_only,
 )
 from .quant_primitives import (
     MappingType,
@@ -117,7 +129,19 @@
     "ALL_AUTOQUANT_CLASS_LIST",
     # top level API - manual
     "quantize_",
+    "int4_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight",
+    "int4_weight_only",
+    "int8_weight_only",
     "intx_quantization_aware_training",
+    "float8_weight_only",
+    "float8_dynamic_activation_float8_weight",
+    "float8_static_activation_float8_weight",
+    "uintx_weight_only",
+    "fpx_weight_only",
+    "gemlite_uintx_weight_only",
     "swap_conv2d_1x1_to_linear",
     "Int4DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt4WeightConfig",

torchao/quantization/quant_api.py

Lines changed: 74 additions & 1 deletion

@@ -96,6 +96,7 @@
     to_weight_tensor_with_linear_activation_quantization_metadata,
 )
 from torchao.utils import (
+    _ConfigDeprecationWrapper,
     is_MI300,
     is_sm_at_least_89,
     is_sm_at_least_90,
@@ -144,7 +145,18 @@
     "autoquant",
     "_get_subclass_inserter",
     "quantize_",
+    "int8_dynamic_activation_int4_weight",
+    "int8_dynamic_activation_int8_weight",
+    "int8_dynamic_activation_int8_semi_sparse_weight",
+    "int4_weight_only",
+    "int8_weight_only",
     "intx_quantization_aware_training",
+    "float8_weight_only",
+    "uintx_weight_only",
+    "fpx_weight_only",
+    "gemlite_uintx_weight_only",
+    "float8_dynamic_activation_float8_weight",
+    "float8_static_activation_float8_weight",
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
@@ -491,7 +503,7 @@ def quantize_(
         # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
         # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
         # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
-        from torchao.quantization.quant_api import Int4WeightOnlyConfig
+        from torchao.quantization.quant_api import int4_weight_only
 
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -613,6 +625,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
+    "int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
+)
+
+
 @register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
 def _int8_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module,
@@ -978,6 +996,12 @@ def __post_init__(self):
         )
 
 
+# for bc
+int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
+    "int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
+)
+
+
 @register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
 def _int4_dynamic_activation_int4_weight_transform(
     module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1035,6 +1059,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
+    "gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
+)
+
+
 @register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
 def _gemlite_uintx_weight_only_transform(
     module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1112,6 +1142,11 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
 
 
+# for BC
+# TODO maybe change other callsites
+int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
+
+
 def _int4_weight_only_quantize_tensor(weight, config):
     # TODO(future PR): perhaps move this logic to a different file, to keep the API
     # file clean of implementation details
@@ -1323,6 +1358,10 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
 
 
+# for BC
+int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
+
+
 def _int8_weight_only_quantize_tensor(weight, config):
     mapping_type = MappingType.SYMMETRIC
     target_dtype = torch.int8
@@ -1480,6 +1519,12 @@ def __post_init__(self):
         )
 
 
+# for BC
+int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
+    "int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
+)
+
+
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
     layout = config.layout
     act_mapping_type = config.act_mapping_type
@@ -1585,6 +1630,12 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
 
 
+# for BC
+float8_weight_only = _ConfigDeprecationWrapper(
+    "float8_weight_only", Float8WeightOnlyConfig
+)
+
+
 def _float8_weight_only_quant_tensor(weight, config):
     if config.version == 1:
         warnings.warn(
@@ -1743,6 +1794,12 @@ def __post_init__(self):
         self.granularity = [activation_granularity, weight_granularity]
 
 
+# for bc
+float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
+    "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
+)
+
+
 def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_dtype = config.activation_dtype
     weight_dtype = config.weight_dtype
@@ -1918,6 +1975,12 @@ def __post_init__(self):
         )
 
 
+# for bc
+float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
+    "float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
+)
+
+
 @register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
 def _float8_static_activation_float8_weight_transform(
     module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2000,6 +2063,12 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")
 
 
+# for BC
+uintx_weight_only = _ConfigDeprecationWrapper(
+    "uintx_weight_only", UIntXWeightOnlyConfig
+)
+
+
 @register_quantize_module_handler(UIntXWeightOnlyConfig)
 def _uintx_weight_only_transform(
     module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2278,6 +2347,10 @@ def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")
 
 
+# for BC
+fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
+
+
 @register_quantize_module_handler(FPXWeightOnlyConfig)
 def _fpx_weight_only_transform(
     module: torch.nn.Module, config: FPXWeightOnlyConfig
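
End to end, either spelling drives `quantize_` identically, since each shim returns the registered config instance. A minimal sketch using the int8 weight-only path (chosen on the assumption that it runs on CPU; other configs may need specific hardware):

```python
import torch
import torch.nn as nn

from torchao.quantization import int8_weight_only, quantize_

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

# Warns once, then dispatches exactly like quantize_(m, Int8WeightOnlyConfig()).
quantize_(m, int8_weight_only())

out = m(torch.randn(2, 32))
```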

torchao/utils.py

Lines changed: 20 additions & 1 deletion

@@ -12,7 +12,7 @@
 from functools import reduce
 from importlib.metadata import version
 from math import gcd
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Type
 
 import torch
 import torch.nn.utils.parametrize as parametrize
@@ -433,6 +433,25 @@ def __eq__(self, other):
 TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")
 
 
+class _ConfigDeprecationWrapper:
+    """
+    A deprecation wrapper that directs users from a deprecated "config function"
+    (e.g. `int4_weight_only`) to the replacement config class.
+    """
+
+    def __init__(self, deprecated_name: str, config_cls: Type):
+        self.deprecated_name = deprecated_name
+        self.config_cls = config_cls
+
+    def __call__(self, *args, **kwargs):
+        warnings.warn(
+            f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
+            f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
+            f"    quantize_(model, {self.config_cls.__name__}(...))"
+        )
+        return self.config_cls(*args, **kwargs)
+
+
 """
 Helper function for implementing aten op or torch function dispatch
 and dispatching to these implementations.
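
Because `_ConfigDeprecationWrapper` is just a callable that warns and forwards `*args, **kwargs` to the config class constructor, wiring up a shim is a one-liner. A hypothetical sketch (`FooConfig` and `foo_weight_only` are invented names for illustration, not torchao APIs):

```python
from dataclasses import dataclass

from torchao.utils import _ConfigDeprecationWrapper


@dataclass
class FooConfig:  # hypothetical stand-in for a real config class
    group_size: int = 128


# Calling foo_weight_only(...) warns, then returns FooConfig(...).
foo_weight_only = _ConfigDeprecationWrapper("foo_weight_only", FooConfig)

cfg = foo_weight_only(group_size=64)
assert isinstance(cfg, FooConfig) and cfg.group_size == 64
```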
