
Commit ed4cd34

Drop old quantization flows (#3115)
* drop old quantization flows
* remove old quantization flows
* inline quantization API in benchmarks
* drop old apis
* fix pre-commit
* revert smoothquant implementation
* revert smoothquant test
* revert unrelated changes
* drop deprecated apis
* fix duplicate import
1 parent 16c7d09 commit ed4cd34
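The change throughout is mechanical: the deprecated tensor-subclass flows (`Int8WeightOnlyQuantizedLinearWeight` and friends, inserted via `_replace_with_custom_fn_if_matches_filter`) give way to the config-based `quantize_` API. A minimal sketch of the new flow, assuming a CUDA build of torchao; the one-layer model is illustrative only:

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_

# Illustrative stand-in for a real model.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
)

# New flow: one call, one config object; no subclass plumbing.
quantize_(model, Int8WeightOnlyConfig(), set_inductor_config=False)
```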

10 files changed: +24 −1176 lines changed


benchmarks/benchmark_aq.py

Lines changed: 24 additions & 93 deletions
```diff
@@ -16,30 +16,6 @@
     _replace_with_custom_fn_if_matches_filter,
     quantize_,
 )
-from torchao.quantization.subclass import (
-    Int4WeightOnlyQuantizedLinearWeight,
-    Int8WeightOnlyQuantizedLinearWeight,
-)
-
-
-def _int8wo_api(mod, **kwargs):
-    quantize_(mod, Int8WeightOnlyConfig(**kwargs), set_inductor_config=False)
-
-
-def _int8da_int8w_api(mod, **kwargs):
-    quantize_(
-        mod,
-        Int8DynamicActivationInt8WeightConfig(**kwargs),
-        set_inductor_config=False,
-    )
-
-
-def _int4wo_api(mod, **kwargs):
-    kwargs_copy = kwargs.copy()
-    if "groupsize" in kwargs_copy:
-        kwargs_copy["group_size"] = kwargs_copy["groupsize"]
-        del kwargs_copy["groupsize"]
-    quantize_(mod, Int4WeightOnlyConfig(**kwargs_copy), set_inductor_config=False)
 
 
 class ToyLinearModel(torch.nn.Module):
```
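The dropped `_int8wo_api`/`_int8da_int8w_api`/`_int4wo_api` wrappers only forwarded kwargs into a config object (renaming `groupsize` to `group_size` along the way), so callers can build the config themselves. A sketch of the direct equivalent of the removed `_int4wo_api(mod, groupsize=32)`; the stand-in module is illustrative:

```python
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Stand-in module; any model containing nn.Linear layers works the same way.
mod = torch.nn.Linear(256, 256, dtype=torch.bfloat16, device="cuda")

# The groupsize -> group_size rename shim is gone; use the canonical keyword.
quantize_(mod, Int4WeightOnlyConfig(group_size=32), set_inductor_config=False)
```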
```diff
@@ -68,34 +44,6 @@ def forward(self, x):
         return x
 
 
-def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
-    """
-    The deprecated implementation for int8 dynamic quant API, used as a reference for
-    numerics and performance
-    """
-    from torchao.quantization.quant_api import (
-        _get_subclass_inserter,
-        _is_linear,
-    )
-    from torchao.quantization.subclass import Int8DynamicallyQuantizedLinearWeight
-
-    def _in_features_greater_than_16(mod, *args):
-        return hasattr(mod, "in_features") and mod.in_features > 16
-
-    if filter_fn is None:
-        filter_fn = lambda *args: _is_linear(*args) and _in_features_greater_than_16(
-            *args
-        )
-
-    _replace_with_custom_fn_if_matches_filter(
-        model,
-        _get_subclass_inserter(
-            Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs
-        ),
-        filter_fn,
-    )
-
-
 def _get_ref_change_linear_weights_to_woqtensors(deprecated_tenosr_subclass):
     def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
         """
```
```diff
@@ -117,38 +65,18 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
     return _ref_change_linear_weights_to_woqtensors
 
 
-_ref_change_linear_weights_to_int8_woqtensors = (
-    _get_ref_change_linear_weights_to_woqtensors(Int8WeightOnlyQuantizedLinearWeight)
-)
-_ref_change_linear_weights_to_int4_woqtensors = (
-    _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)
-)
-
-
 torch._dynamo.config.cache_size_limit = 50000
 
 
 @torch.no_grad
-def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
-    if kwargs is None:
-        kwargs = {}
-
+def _bench_quantized_tensor_subclass_perf(api, config, M, N, K):
     m = ToyLinearModel(
         M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda"
     ).eval()
     m_bf16 = copy.deepcopy(m)
-    m_ref = copy.deepcopy(m)
     example_inputs = m.example_inputs()
 
-    api(m, **kwargs)
-
-    # reference
-    ref_api(m_ref, **kwargs)
-
-    res = m(*example_inputs)
-    ref = m_ref(*example_inputs)
-
-    assert torch.equal(res, ref)
+    api(m, config)  # Pass both model and config
 
     # perf comparison
     from torchao.utils import benchmark_model
@@ -158,22 +86,17 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     RUNS = 100
 
     torch._dynamo.reset()
-    m_ref = torch.compile(m_ref, mode="max-autotune", fullgraph=True)
-    benchmark_model(m_ref, WARMUP, example_inputs)
-    ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)
+    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
+    benchmark_model(m_bf16, WARMUP, example_inputs)
+    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
 
     torch._dynamo.reset()
     m = torch.compile(m, mode="max-autotune", fullgraph=True)
     benchmark_model(m, WARMUP, example_inputs)
     elapsed_time = benchmark_model(m, RUNS, example_inputs)
 
-    torch._dynamo.reset()
-    m_bf16 = torch.compile(m_bf16, mode="max-autotune", fullgraph=True)
-    benchmark_model(m_bf16, WARMUP, example_inputs)
-    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)
-
     print(
-        f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
+        f"{(M, N, K)}: elapsed time: {elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}"
     )
 
 
```
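With the reference path gone, the helper times only the quantized model against the bf16 baseline. Its measurement pattern, extracted as a standalone sketch: `benchmark_model` comes from `torchao.utils` as in the file, while `time_compiled` is a hypothetical name and the warmup/run counts are the caller's choice (the file uses `RUNS = 100`; its `WARMUP` value is not shown in this hunk):

```python
import torch
from torchao.utils import benchmark_model

@torch.no_grad
def time_compiled(model, example_inputs, warmup, runs):
    # Clear compilation caches so each variant is measured from a cold start.
    torch._dynamo.reset()
    compiled = torch.compile(model, mode="max-autotune", fullgraph=True)
    benchmark_model(compiled, warmup, example_inputs)       # warm-up passes
    return benchmark_model(compiled, runs, example_inputs)  # timed passes
```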

```diff
@@ -182,24 +105,32 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
     (20, 2048, 2048),
 ]
 
-print("_int8da_int8w_api")
-
+print("Int8DynamicActivationInt8WeightConfig")
 for M, N, K in all_shapes:
     _bench_quantized_tensor_subclass_perf(
-        _int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K
+        quantize_,
+        Int8DynamicActivationInt8WeightConfig(),
+        M,
+        N,
+        K,
     )
 
-print("_int8wo_api")
-
+print("Int8WeightOnlyConfig")
 for M, N, K in all_shapes:
     _bench_quantized_tensor_subclass_perf(
-        _int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K
+        quantize_,
+        Int8WeightOnlyConfig(),
+        M,
+        N,
+        K,
     )
 
-print("_int4wo_api")
-kwargs = {"groupsize": 32, "version": 1}
-
+print("Int4WeightOnlyConfig")
 for M, N, K in all_shapes:
     _bench_quantized_tensor_subclass_perf(
-        _int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs
+        quantize_,
+        Int4WeightOnlyConfig(group_size=32),
+        M,
+        N,
+        K,
     )
```
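Since the rewritten helper accepts any config object, adding a scheme to the benchmark is now a config swap rather than a new pair of wrapper functions. A hypothetical extra case (not part of the commit), assuming the definitions in `benchmarks/benchmark_aq.py` (`_bench_quantized_tensor_subclass_perf`, the config imports, `all_shapes`) are in scope:

```python
# Hypothetical additional benchmark case: same helper, different config.
print("Int4WeightOnlyConfig, group_size=128")
for M, N, K in all_shapes:
    _bench_quantized_tensor_subclass_perf(
        quantize_, Int4WeightOnlyConfig(group_size=128), M, N, K
    )
```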

test/integration/test_integration.py

Lines changed: 0 additions & 118 deletions
```diff
@@ -56,11 +56,6 @@
     smooth_fq_linear_to_inference,
     swap_linear_with_smooth_fq_linear,
 )
-from torchao.quantization.subclass import (
-    Int4WeightOnlyQuantizedLinearWeight,
-    Int8DynamicallyQuantizedLinearWeight,
-    Int8WeightOnlyQuantizedLinearWeight,
-)
 from torchao.quantization.utils import (
     LoggingTensorMode,
     _apply_logging_hook,
@@ -681,62 +676,6 @@ def _test_dequantize_impl(
             f"{lin.weight.__class__.__name__} failed transpose on dtype={test_dtype}",
         )
 
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_dequantize_int8_dynamic_quant_subclass(self, device, dtype):
-        self._test_dequantize_impl(
-            Int8DynamicallyQuantizedLinearWeight.from_float,
-            device,
-            35,
-            test_dtype=dtype,
-        )
-
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):
-        self._test_dequantize_impl(
-            Int8WeightOnlyQuantizedLinearWeight.from_float, device, 35, test_dtype=dtype
-        )
-
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
-        if dtype != torch.bfloat16:
-            self.skipTest("Currently only supports bfloat16.")
-        for test_shape in [(16, 1024, 16)] + (
-            [(1, 1024, 8)] if device == "cuda" else []
-        ):
-            self._test_dequantize_impl(
-                Int4WeightOnlyQuantizedLinearWeight.from_float,
-                device,
-                15,
-                test_shape=test_shape,
-                test_dtype=dtype,
-            )
-
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
-        if dtype != torch.bfloat16:
-            self.skipTest("Currently only supports bfloat16.")
-        m_shapes = [16, 256] + ([1] if device == "cuda" else [])
-        n_shapes = [16] + ([8, 13] if device == "cuda" else [])
-        for groupsize in [256, 128]:
-            for inner_k_tiles in [8, 4, 2]:
-                for m in m_shapes:
-                    for n in n_shapes:
-                        self._test_dequantize_impl(
-                            lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(
-                                w, groupsize, inner_k_tiles
-                            ),
-                            device,
-                            15,
-                            test_shape=[m, 256, n],
-                            test_dtype=dtype,
-                        )
-
     @run_supported_device_dtype
     def _test_lin_weight_subclass_impl(
         self,
@@ -771,22 +710,6 @@ def _test_lin_weight_subclass_impl(
             f"{lin.weight.__class__.__name__} failed at compile with dtype={test_dtype}, (m, k, n)={test_shape}",
         )
 
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_dynamic_quant_subclass(self, device, dtype):
-        self._test_lin_weight_subclass_impl(
-            Int8DynamicallyQuantizedLinearWeight.from_float,
-            device,
-            35,
-            test_dtype=dtype,
-        )
-
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    def test_int8_weight_only_quant_subclass(self, device, dtype):
-        undo_recommended_configs()
-        self._test_lin_weight_subclass_impl(
-            Int8WeightOnlyQuantizedLinearWeight.from_float, device, 40, test_dtype=dtype
-        )
-
     @parameterized.expand(COMMON_DEVICE_DTYPE)
     def test_aq_int8_dynamic_quant_subclass(self, device, dtype):
         self._test_lin_weight_subclass_impl(
@@ -891,46 +814,6 @@ def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype
             test_dtype=dtype,
         )
 
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @skip_if_rocm("ROCm enablement in progress")
-    def test_int4_weight_only_quant_subclass(self, device, dtype):
-        if device == "cpu":
-            self.skipTest(f"Temporarily skipping for {device}")
-        if dtype != torch.bfloat16:
-            self.skipTest(f"Fails for {dtype}")
-        for test_shape in [(16, 1024, 16)] + (
-            [(1, 1024, 8)] if device == "cuda" else []
-        ):
-            self._test_lin_weight_subclass_impl(
-                Int4WeightOnlyQuantizedLinearWeight.from_float,
-                device,
-                10,
-                test_shape=test_shape,
-                test_dtype=dtype,
-            )
-
-    @parameterized.expand(COMMON_DEVICE_DTYPE)
-    @skip_if_rocm("ROCm enablement in progress")
-    @unittest.skip("Skip to fix CI until we deprecate these APIs long term")
-    def test_int4_weight_only_quant_subclass_grouped(self, device, dtype):
-        if dtype != torch.bfloat16:
-            self.skipTest(f"Fails for {dtype}")
-        m_shapes = [16, 256] + ([1] if device == "cuda" else [])
-        n_shapes = [16] + ([8, 13] if device == "cuda" else [])
-        for groupsize in [128, 64]:
-            for inner_k_tiles in [8, 4, 2]:
-                for m in m_shapes:
-                    for n in n_shapes:
-                        self._test_lin_weight_subclass_impl(
-                            lambda w: Int4WeightOnlyQuantizedLinearWeight.from_float(
-                                w, groupsize, inner_k_tiles
-                            ),
-                            device,
-                            10,
-                            test_shape=[m, 256, n],
-                            test_dtype=dtype,
-                        )
-
     @torch.no_grad()
     @run_supported_device_dtype
     def _test_lin_weight_subclass_api_impl(
@@ -1120,7 +1003,6 @@ def test_dynamic_quant(self):
 
         sqnr = compute_error(y_ref, y_test)
         self.assertGreater(sqnr, 40.0)
-        # self.assertTrue(isinstance(m[0], DynamicallyPerAxisQuantizedLinear))
 
 
 class TestWeightOnlyInt8Quant(unittest.TestCase):
```
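The removed subclass tests asserted quantize/dequantize numerics against a float reference. Under the config API the same kind of check can be sketched with `compute_error`, the SQNR helper the surviving tests use. Assumptions: CUDA and bfloat16 as in the removed tests, and the 40 dB threshold borrowed from `test_dynamic_quant`:

```python
import copy

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error

lin = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
y_ref = lin(x)

qlin = copy.deepcopy(lin)
quantize_(qlin, Int8WeightOnlyConfig(), set_inductor_config=False)
y_test = qlin(x)

sqnr = compute_error(y_ref, y_test)  # signal-to-quantization-noise ratio, in dB
assert sqnr > 40.0, f"SQNR regression: {sqnr}"
```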
