
Commit 775d87b

Remove old QAT APIs
**Summary:** As a follow-up to #2641, which deprecated the old QAT APIs in 0.13.0, this change removes them in the next release, 0.15.0. Fixes #2630.

**Test Plan:** CI

ghstack-source-id: 8ddff9e
Pull Request resolved: #3147
1 parent fda8820 commit 775d87b
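
For anyone migrating, the replacement pattern is the same throughout the diff below: the old prepare/convert config pair is folded into QATConfig with an explicit step argument. A minimal before/after sketch using only calls that appear in this commit; the toy model is illustrative, and the real configs may have dtype/backend requirements not shown here:

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.qat import QATConfig
    from torchao.quantization.quant_api import Int8DynamicActivationInt4WeightConfig

    model = torch.nn.Sequential(torch.nn.Linear(32, 32))  # illustrative toy model
    base_config = Int8DynamicActivationInt4WeightConfig(group_size=16)

    # Old API (removed in this commit):
    #   quantize_(model, IntXQuantizationAwareTrainingConfig(act_config, weight_config))
    #   ... train ...
    #   quantize_(model, FromIntXQuantizationAwareTrainingConfig())
    #   quantize_(model, base_config)

    # New API:
    quantize_(model, QATConfig(base_config, step="prepare"))
    # ... train ...
    quantize_(model, QATConfig(base_config, step="convert"))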

File tree: 10 files changed (+15, -270 lines)


docs/source/api_ref_qat.rst

Lines changed: 0 additions & 2 deletions
@@ -42,8 +42,6 @@ Legacy QAT APIs
     :toctree: generated/
     :nosignatures:

-    IntXQuantizationAwareTrainingConfig
-    FromIntXQuantizationAwareTrainingConfig
     Int4WeightOnlyQATQuantizer
     linear.Int4WeightOnlyQATLinear
     Int8DynActInt4WeightQATQuantizer

test/prototype/test_embedding.py

Lines changed: 5 additions & 6 deletions
@@ -18,10 +18,9 @@
 )
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int4WeightOnlyEmbeddingQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationIntxWeightConfig,
@@ -257,7 +256,7 @@ def test_identical_to_IntxWeightOnlyConfig(
         ],
         name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
         self, weight_dtype, granularity, mapping_type, scale_dtype, model_dtype
     ):
         # ASYMMETRIC in QAT is very different that PTQ configs
@@ -288,12 +287,12 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         )
         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(weight_config=weight_config),
+            QATConfig(weight_config=weight_config, step="prepare"),
             embedding_filter,
         )
         prepared_out = model(indices)

-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
@@ -355,7 +354,7 @@ def test_identical_to_Int4WeightOnlyEmbeddingQATQuantizer(
         prepared_out = model(indices)

         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig(), embedding_filter)
+        quantize_(model, QATConfig(step="convert"), embedding_filter)
         quantize_(
             model,
             IntxWeightOnlyConfig(
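
The new embedding flow mirrors the replacement above. A minimal sketch based only on the calls shown in this diff; the toy model, filter function, and dtype choice are illustrative, not part of the change:

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

    # Illustrative toy model with an embedding layer
    model = torch.nn.Sequential(torch.nn.Embedding(128, 64), torch.nn.Linear(64, 16))

    def embedding_filter(module, fqn):
        # Only apply QAT to embedding modules, as the test above does
        return isinstance(module, torch.nn.Embedding)

    # Prepare: fake-quantize embedding weights (int8 per group of 32 is a placeholder choice)
    weight_config = IntxFakeQuantizeConfig(torch.int8, group_size=32, is_symmetric=True)
    quantize_(model, QATConfig(weight_config=weight_config, step="prepare"), embedding_filter)

    # ... train ...

    # Convert: swap back to plain nn.Embedding before applying a PTQ config such as IntxWeightOnlyConfig
    quantize_(model, QATConfig(step="convert"), embedding_filter)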

test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py

Lines changed: 9 additions & 6 deletions
@@ -20,10 +20,9 @@
 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import (
-    FromIntXQuantizationAwareTrainingConfig,
     Int8DynActInt4WeightQATQuantizer,
     IntxFakeQuantizeConfig,
-    IntXQuantizationAwareTrainingConfig,
+    QATConfig,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt4WeightConfig,
@@ -499,7 +498,7 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
     )
-    def test_identical_to_IntXQuantizationAwareTrainingConfig(
+    def test_identical_to_QATConfig(
        self,
        weight_dtype,
        group_size,
@@ -545,7 +544,11 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(

         quantize_(
             model,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
+            QATConfig(
+                activation_config=activation_config,
+                weight_config=weight_config,
+                step="prepare",
+            ),
         )
         try:
             prepared_out = model(activations)
@@ -555,7 +558,7 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
                 return
             raise e

-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
@@ -606,7 +609,7 @@ def test_identical_to_Int8DynActInt4WeightQATQuantizer(
         prepared_out = model(activations)

         # Convert model method 1
-        quantize_(model, FromIntXQuantizationAwareTrainingConfig())
+        quantize_(model, QATConfig(step="convert"))
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
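
The keyword-argument form of QATConfig shown above can also be written out directly. A minimal sketch assuming a plain linear model; the model is illustrative, and the fake-quantize settings mirror the old docstring example removed later in this commit (torch.int4 requires a PyTorch version that exposes that dtype):

    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

    model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # illustrative toy model

    # Per-token asymmetric activations, per-group symmetric weights
    activation_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
    weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)

    # Prepare: insert fake quantization into the linear layers
    quantize_(
        model,
        QATConfig(
            activation_config=activation_config,
            weight_config=weight_config,
            step="prepare",
        ),
    )

    # ... train, then strip fake quantization before applying the real PTQ config ...
    quantize_(model, QATConfig(step="convert"))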

test/quantization/test_qat.py

Lines changed: 0 additions & 92 deletions
@@ -9,7 +9,6 @@

 import copy
 import unittest
-import warnings
 from typing import List, Type

 import torch
@@ -39,8 +38,6 @@
 )
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
     initialize_fake_quantizers,
@@ -1718,95 +1715,6 @@ def test_qat_fp8a4w_quantizer(self):
         self.assertNotEqual(torch.count_nonzero(new_weight.grad), 0)
         self.assertFalse(torch.equal(new_weight, prev_weight))

-    def test_legacy_quantize_api_e2e(self):
-        """
-        Test that the following two APIs are numerically equivalent:
-
-        New API:
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="prepare"))
-            quantize_(model, QATConfig(Int8DynamicActivationInt4WeightConfig(), step="convert"))
-
-        Old API:
-            quantize_(model, IntXQuantizationAwareTrainingConfig(...))
-            quantize_(model, FromIntXQuantizationAwareTrainingConfig())
-            quantize_(model, Int8DynamicActivationInt4WeightConfig())
-        """
-        group_size = 16
-        torch.manual_seed(self.SEED)
-        m = M()
-        baseline_model = copy.deepcopy(m)
-
-        # Baseline prepare
-        act_config = IntxFakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
-        weight_config = IntxFakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
-        old_qat_config = IntXQuantizationAwareTrainingConfig(act_config, weight_config)
-        quantize_(baseline_model, old_qat_config)
-
-        # QATConfig prepare
-        base_config = Int8DynamicActivationInt4WeightConfig(group_size=group_size)
-        quantize_(m, QATConfig(base_config, step="prepare"))
-
-        # Compare prepared values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-        # Baseline convert
-        quantize_(baseline_model, FromIntXQuantizationAwareTrainingConfig())
-        quantize_(baseline_model, base_config)
-
-        # quantize_ convert
-        quantize_(m, QATConfig(base_config, step="convert"))
-
-        # Compare converted values
-        torch.manual_seed(self.SEED)
-        x = m.example_inputs()
-        x2 = copy.deepcopy(x)
-        out = m(*x)
-        baseline_out = baseline_model(*x2)
-        torch.testing.assert_close(out, baseline_out, atol=0, rtol=0)
-
-    def test_qat_api_deprecation(self):
-        """
-        Test that the appropriate deprecation warning is logged exactly once per class.
-        """
-        from torchao.quantization.qat import (
-            FakeQuantizeConfig,
-            FakeQuantizer,
-            from_intx_quantization_aware_training,
-            intx_quantization_aware_training,
-        )
-
-        # Reset deprecation warning state, otherwise we won't log warnings here
-        warnings.resetwarnings()
-
-        # Map from deprecated API to the args needed to instantiate it
-        deprecated_apis_to_args = {
-            IntXQuantizationAwareTrainingConfig: (),
-            FromIntXQuantizationAwareTrainingConfig: (),
-            intx_quantization_aware_training: (),
-            from_intx_quantization_aware_training: (),
-            FakeQuantizeConfig: (torch.int8, "per_channel"),
-            FakeQuantizer: (IntxFakeQuantizeConfig(torch.int8, "per_channel"),),
-        }
-
-        with warnings.catch_warnings(record=True) as _warnings:
-            # Call each deprecated API twice
-            for cls, args in deprecated_apis_to_args.items():
-                cls(*args)
-                cls(*args)
-
-        # Each call should trigger the warning only once
-        self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
-        for w in _warnings:
-            self.assertIn(
-                "is deprecated and will be removed in a future release",
-                str(w.message),
-            )
-
     def test_qat_api_convert_no_quantization(self):
         """
         Test that `QATConfig(step="convert")` swaps back to nn modules without quantization.

torchao/quantization/prototype/qat/api.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from torchao.quantization.qat.api import (
     ComposableQATQuantizer,
-    FakeQuantizeConfig,
+    IntxFakeQuantizeConfig as FakeQuantizeConfig,
 )

 __all__ = [
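
Since the prototype shim now re-exports IntxFakeQuantizeConfig under the old name, the old import path should still resolve to the new class. A small sanity check, assuming the prototype module path stays importable:

    from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
    from torchao.quantization.qat import IntxFakeQuantizeConfig

    # The shim is an alias, so both names refer to the same class object
    assert FakeQuantizeConfig is IntxFakeQuantizeConfig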

torchao/quantization/qat/__init__.py

Lines changed: 0 additions & 13 deletions
@@ -1,25 +1,19 @@
 from .api import (
     ComposableQATQuantizer,
-    FromIntXQuantizationAwareTrainingConfig,
-    IntXQuantizationAwareTrainingConfig,
     QATConfig,
     QATStep,
-    from_intx_quantization_aware_training,
     initialize_fake_quantizers,
-    intx_quantization_aware_training,
 )
 from .embedding import (
     FakeQuantizedEmbedding,
     Int4WeightOnlyEmbeddingQATQuantizer,
 )
 from .fake_quantize_config import (
-    FakeQuantizeConfig,
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .fake_quantizer import (
-    FakeQuantizer,
     FakeQuantizerBase,
     Float8FakeQuantizer,
     IntxFakeQuantizer,
@@ -50,11 +44,4 @@
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int4WeightOnlyQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
-    # for BC
-    "FakeQuantizer",
-    "FakeQuantizeConfig",
-    "from_intx_quantization_aware_training",
-    "FromIntXQuantizationAwareTrainingConfig",
-    "intx_quantization_aware_training",
-    "IntXQuantizationAwareTrainingConfig",
 ]

torchao/quantization/qat/api.py

Lines changed: 0 additions & 115 deletions
@@ -21,13 +21,11 @@

 from .embedding import FakeQuantizedEmbedding
 from .fake_quantize_config import (
-    FakeQuantizeConfig,  # noqa: F401, for BC
     FakeQuantizeConfigBase,
     IntxFakeQuantizeConfig,
     _infer_fake_quantize_configs,
 )
 from .linear import FakeQuantizedLinear
-from .utils import _log_deprecation_warning


 class QATStep(str, Enum):
@@ -288,119 +286,6 @@ def _qat_config_transform(
         return module


-@dataclass
-class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
-    """
-    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
-
-    Config for applying fake quantization to a `torch.nn.Module`.
-    to be used with :func:`~torchao.quantization.quant_api.quantize_`.
-
-    Example usage::
-
-        from torchao.quantization import quantize_
-        from torchao.quantization.qat import IntxFakeQuantizeConfig
-        activation_config = IntxFakeQuantizeConfig(
-            torch.int8, "per_token", is_symmetric=False,
-        )
-        weight_config = IntxFakeQuantizeConfig(
-            torch.int4, group_size=32, is_symmetric=True,
-        )
-        quantize_(
-            model,
-            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
-        )
-
-    Note: If the config is applied on a module that is not
-    `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
-    `torch.nn.Embedding` with an activation config, then we will raise
-    ValueError as these are not supported.
-    """
-
-    activation_config: Optional[FakeQuantizeConfigBase] = None
-    weight_config: Optional[FakeQuantizeConfigBase] = None
-
-    def __post_init__(self):
-        _log_deprecation_warning(self)
-
-
-# for BC
-class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
-    pass
-
-
-@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
-def _intx_quantization_aware_training_transform(
-    module: torch.nn.Module,
-    config: IntXQuantizationAwareTrainingConfig,
-) -> torch.nn.Module:
-    mod = module
-    activation_config = config.activation_config
-    weight_config = config.weight_config
-
-    if isinstance(mod, torch.nn.Linear):
-        return FakeQuantizedLinear.from_linear(
-            mod,
-            activation_config,
-            weight_config,
-        )
-    elif isinstance(mod, torch.nn.Embedding):
-        if activation_config is not None:
-            raise ValueError(
-                "Activation fake quantization is not supported for embedding"
-            )
-        return FakeQuantizedEmbedding.from_embedding(mod, weight_config)
-    else:
-        raise ValueError("Module of type '%s' does not have QAT support" % type(mod))
-
-
-@dataclass
-class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
-    """
-    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.
-
-    Config for converting a model with fake quantized modules,
-    such as :func:`~torchao.quantization.qat.linear.FakeQuantizedLinear`
-    and :func:`~torchao.quantization.qat.linear.FakeQuantizedEmbedding`,
-    back to model with the original, corresponding modules without
-    fake quantization. This should be used with
-    :func:`~torchao.quantization.quant_api.quantize_`.
-
-    Example usage::
-
-        from torchao.quantization import quantize_
-        quantize_(
-            model_with_fake_quantized_linears,
-            FromIntXQuantizationAwareTrainingConfig(),
-        )
-    """
-
-    def __post_init__(self):
-        _log_deprecation_warning(self)
-
-
-# for BC
-class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
-    pass
-
-
-@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
-def _from_intx_quantization_aware_training_transform(
-    mod: torch.nn.Module,
-    config: FromIntXQuantizationAwareTrainingConfig,
-) -> torch.nn.Module:
-    """
-    If the given module is a fake quantized module, return the original
-    corresponding version of the module without fake quantization.
-    """
-    if isinstance(mod, FakeQuantizedLinear):
-        return mod.to_linear()
-    elif isinstance(mod, FakeQuantizedEmbedding):
-        return mod.to_embedding()
-    else:
-        return mod
-
-
 class ComposableQATQuantizer(TwoStepQuantizer):
     """
     Composable quantizer that users can use to apply multiple QAT quantizers easily.
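
The removed transforms were thin wrappers around FakeQuantizedLinear and FakeQuantizedEmbedding, which this commit keeps. A minimal sketch of performing the same prepare/unwrap swap by hand, assuming those classes retain the from_linear/to_linear methods called in the deleted code; the toy layer and settings are illustrative:

    import torch
    from torchao.quantization.qat import IntxFakeQuantizeConfig
    from torchao.quantization.qat.linear import FakeQuantizedLinear

    linear = torch.nn.Linear(16, 16)
    weight_config = IntxFakeQuantizeConfig(torch.int8, group_size=8, is_symmetric=True)

    # Prepare: wrap the linear so its weight is fake-quantized during training
    fq_linear = FakeQuantizedLinear.from_linear(linear, None, weight_config)

    # Convert: unwrap back to a plain nn.Linear, as the deleted convert transform did
    plain_linear = fq_linear.to_linear()
    assert isinstance(plain_linear, torch.nn.Linear)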
