Skip to content

Commit 82eb780

Browse files
authored
Move dyn_int8_act_int4_wei_cpu_layout to prototype/dtypes (#3299)
1 parent 0c081fd commit 82eb780

File tree

10 files changed

+388
-369
lines changed

10 files changed

+388
-369
lines changed

docs/source/api_ref_dtypes.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Prototype
5252

5353
BlockSparseLayout
5454
CutlassInt4PackedLayout
55+
Int8DynamicActInt4WeightCPULayout
5556

5657
..
5758
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring

test/dtypes/test_uintx.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import sys
7+
import warnings
8+
69
import pytest
710
import torch
811

@@ -165,3 +168,39 @@ def test_uintx_model_size(dtype):
165168
quantize_(linear[0], UIntXWeightOnlyConfig(dtype))
166169
quantized_size = get_model_size_in_bytes(linear)
167170
assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
171+
172+
173+
def test_uintx_api_deprecation():
174+
"""
175+
Test that deprecated uintx APIs trigger deprecation warnings on import.
176+
TODO: Remove this test once the deprecated APIs have been removed.
177+
"""
178+
deprecated_apis = [
179+
(
180+
"Int8DynamicActInt4WeightCPULayout",
181+
"torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout",
182+
),
183+
("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"),
184+
("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"),
185+
]
186+
187+
for api_name, module_path in deprecated_apis:
188+
# Clear the cache to force re-importing and trigger the warning again
189+
modules_to_clear = [module_path, "torchao.dtypes"]
190+
for mod in modules_to_clear:
191+
if mod in sys.modules:
192+
del sys.modules[mod]
193+
194+
with warnings.catch_warnings(record=True) as w:
195+
warnings.simplefilter("always") # Ensure all warnings are captured
196+
197+
# Dynamically import the deprecated API
198+
exec(f"from torchao.dtypes import {api_name}")
199+
200+
assert any(
201+
issubclass(warning.category, DeprecationWarning)
202+
and api_name in str(warning.message)
203+
for warning in w
204+
), (
205+
f"Expected deprecation warning for {api_name}, got: {[str(warning.message) for warning in w]}"
206+
)

test/integration/test_integration.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1948,32 +1948,5 @@ def test_benchmark_model_cpu(self):
19481948
assert self.run_benchmark_model("cpu") is not None
19491949

19501950

1951-
# TODO: Remove this test once the deprecated API has been removed
1952-
def test_cutlass_int4_packed_layout_deprecated():
1953-
import sys
1954-
import warnings
1955-
1956-
# We need to clear the cache to force re-importing and trigger the warning again.
1957-
modules_to_clear = [
1958-
"torchao.dtypes.uintx.cutlass_int4_packed_layout",
1959-
"torchao.dtypes",
1960-
]
1961-
for mod in modules_to_clear:
1962-
if mod in sys.modules:
1963-
del sys.modules[mod]
1964-
1965-
with warnings.catch_warnings(record=True) as w:
1966-
from torchao.dtypes import CutlassInt4PackedLayout # noqa: F401
1967-
1968-
warnings.simplefilter("always") # Ensure all warnings are captured
1969-
assert any(
1970-
issubclass(warning.category, DeprecationWarning)
1971-
and "CutlassInt4PackedLayout" in str(warning.message)
1972-
for warning in w
1973-
), (
1974-
f"Expected deprecation warning for CutlassInt4PackedLayout, got: {[str(warning.message) for warning in w]}"
1975-
)
1976-
1977-
19781951
if __name__ == "__main__":
19791952
unittest.main()

test/sparsity/test_sparse_api.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -267,33 +267,6 @@ def test_sparse(self, compile):
267267

268268
torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
269269

270-
# TODO: Remove this test once the deprecated API has been removed
271-
def test_sparse_deprecated(self):
272-
import sys
273-
import warnings
274-
275-
# We need to clear the cache to force re-importing and trigger the warning again.
276-
modules_to_clear = [
277-
"torchao.dtypes.uintx.block_sparse_layout",
278-
"torchao.dtypes",
279-
]
280-
for mod in modules_to_clear:
281-
if mod in sys.modules:
282-
del sys.modules[mod]
283-
284-
with warnings.catch_warnings(record=True) as w:
285-
from torchao.dtypes import BlockSparseLayout # noqa: F401
286-
287-
warnings.simplefilter("always") # Ensure all warnings are captured
288-
self.assertTrue(
289-
any(
290-
issubclass(warning.category, DeprecationWarning)
291-
and "BlockSparseLayout" in str(warning.message)
292-
for warning in w
293-
),
294-
f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
295-
)
296-
297270

298271
common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
299272
common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)

torchao/dtypes/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .uintx import (
1717
Int4CPULayout,
1818
Int4XPULayout,
19-
Int8DynamicActInt4WeightCPULayout,
2019
MarlinQQQLayout,
2120
MarlinQQQTensor,
2221
MarlinSparseLayout,
@@ -29,6 +28,7 @@
2928
)
3029
from .uintx.block_sparse_layout import BlockSparseLayout
3130
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
31+
from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout
3232
from .utils import (
3333
Layout,
3434
PlainLayout,

torchao/dtypes/affine_quantized_tensor_ops.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@
2525
_linear_f16_bf16_act_floatx_weight_check,
2626
_linear_f16_bf16_act_floatx_weight_impl,
2727
)
28-
from torchao.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
29-
_linear_int8_act_int4_weight_cpu_check,
30-
_linear_int8_act_int4_weight_cpu_impl,
31-
)
3228
from torchao.dtypes.uintx.gemlite_layout import (
3329
_linear_fp_act_int4_weight_gemlite_check,
3430
_linear_fp_act_int4_weight_gemlite_impl,
@@ -94,6 +90,10 @@
9490
_linear_int8_act_int4_weight_cutlass_check,
9591
_linear_int8_act_int4_weight_cutlass_impl,
9692
)
93+
from torchao.prototype.dtypes.uintx.dyn_int8_act_int4_wei_cpu_layout import (
94+
_linear_int8_act_int4_weight_cpu_check,
95+
_linear_int8_act_int4_weight_cpu_impl,
96+
)
9797
from torchao.quantization.quant_primitives import (
9898
ZeroPointDomain,
9999
_dequantize_affine_no_zero_point,

0 commit comments

Comments
 (0)