
Commit 42f28f7

Move block_sparse_layout to prototype (#3276)
1 parent 1c89061 commit 42f28f7

File tree: 9 files changed, +314 −235 lines
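
In short, this commit moves `BlockSparseLayout` out of the stable `torchao.dtypes` namespace into `torchao.prototype.dtypes`, leaving a deprecation shim behind so existing imports keep working. A minimal sketch of what user code sees after the change (assuming a torchao build that includes this commit):

```python
# New location going forward
from torchao.prototype.dtypes import BlockSparseLayout

# Old location still resolves through the compatibility stub, but the first
# import in a process emits a DeprecationWarning pointing at the new path.
from torchao.dtypes import BlockSparseLayout  # noqa: F811
```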

docs/source/api_ref_dtypes.rst
Lines changed: 11 additions & 1 deletion

```diff
@@ -22,7 +22,6 @@ Layouts and Tensor Subclasses
     FloatxTensor
     FloatxTensorCoreLayout
     MarlinSparseLayout
-    BlockSparseLayout
     UintxLayout
     MarlinQQQTensor
     MarlinQQQLayout
@@ -43,6 +42,17 @@ Quantization techniques
     to_affine_quantized_floatx_static
     to_marlinqqq_quantized_intx
     to_nf4
+
+Prototype
+---------
+.. currentmodule:: torchao.prototype.dtypes
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    BlockSparseLayout
+
 ..
     _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
     of torchao.dtypes.nf4tensor.NF4Tensor.dequantize_scalers:6:Unexpected indentation.
```

test/sparsity/test_sparse_api.py
Lines changed: 28 additions & 1 deletion

```diff
@@ -253,7 +253,7 @@ def test_sparse(self, compile):
         quantize_(model_copy, Int8DynamicActivationInt8WeightConfig())
         reference = model_copy(input)
 
-        from torchao.dtypes import BlockSparseLayout
+        from torchao.prototype.dtypes import BlockSparseLayout
 
         quantize_(
             model,
@@ -267,6 +267,33 @@ def test_sparse(self, compile):
 
         torch.testing.assert_close(reference, sparse_result, rtol=1e-1, atol=1e-1)
 
+    # TODO: Remove this test once the deprecated API has been removed
+    def test_sparse_deprecated(self):
+        import sys
+        import warnings
+
+        # We need to clear the cache to force re-importing and trigger the warning again.
+        modules_to_clear = [
+            "torchao.dtypes.uintx.block_sparse_layout",
+            "torchao.dtypes",
+        ]
+        for mod in modules_to_clear:
+            if mod in sys.modules:
+                del sys.modules[mod]
+
+        with warnings.catch_warnings(record=True) as w:
+            from torchao.dtypes import BlockSparseLayout  # noqa: F401
+
+            warnings.simplefilter("always")  # Ensure all warnings are captured
+            self.assertTrue(
+                any(
+                    issubclass(warning.category, DeprecationWarning)
+                    and "BlockSparseLayout" in str(warning.message)
+                    for warning in w
+                ),
+                f"Expected deprecation warning for BlockSparseLayout, got: {[str(w.message) for w in w]}",
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestSemiStructuredSparse)
 common_utils.instantiate_parametrized_tests(TestQuantSemiSparse)
```
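
For context on how the layout is exercised, the test above pairs it with int8 dynamic quantization. A hedged sketch of that wiring (the toy model and the `layout=` keyword are assumptions based on the test's imports, not shown verbatim in the diff; the dispatch check requires a CUDA weight):

```python
import torch

from torchao.prototype.dtypes import BlockSparseLayout
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

# Hypothetical toy model; the int8 block-sparse linear path only dispatches
# for CUDA weights (see _linear_int8_act_int8_weight_block_sparse_check).
model = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda()

# Assumes the config accepts a `layout` argument; blocksize=64 is the
# layout's documented default.
quantize_(
    model,
    Int8DynamicActivationInt8WeightConfig(layout=BlockSparseLayout(blocksize=64)),
)
```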

torchao/dtypes/__init__.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,6 @@
 )
 from .nf4tensor import NF4Tensor, to_nf4
 from .uintx import (
-    BlockSparseLayout,
     CutlassInt4PackedLayout,
     Int4CPULayout,
     Int4XPULayout,
@@ -29,6 +28,7 @@
     UintxLayout,
     to_marlinqqq_quantized_intx,
 )
+from .uintx.block_sparse_layout import BlockSparseLayout
 from .utils import (
     Layout,
     PlainLayout,
```

torchao/dtypes/affine_quantized_tensor_ops.py
Lines changed: 4 additions & 4 deletions

```diff
@@ -25,10 +25,6 @@
     _linear_f16_bf16_act_floatx_weight_check,
     _linear_f16_bf16_act_floatx_weight_impl,
 )
-from torchao.dtypes.uintx.block_sparse_layout import (
-    _linear_int8_act_int8_weight_block_sparse_check,
-    _linear_int8_act_int8_weight_block_sparse_impl,
-)
 from torchao.dtypes.uintx.cutlass_int4_packed_layout import (
     _linear_int4_act_int4_weight_cutlass_check,
     _linear_int4_act_int4_weight_cutlass_impl,
@@ -94,6 +90,10 @@
     _linear_bf16_act_uint4_weight_check,
     _linear_bf16_act_uint4_weight_impl,
 )
+from torchao.prototype.dtypes.uintx.block_sparse_layout import (
+    _linear_int8_act_int8_weight_block_sparse_check,
+    _linear_int8_act_int8_weight_block_sparse_impl,
+)
 from torchao.quantization.quant_primitives import (
     ZeroPointDomain,
     _dequantize_affine_no_zero_point,
```

torchao/dtypes/uintx/__init__.py
Lines changed: 0 additions & 4 deletions

```diff
@@ -1,6 +1,3 @@
-from .block_sparse_layout import (
-    BlockSparseLayout,
-)
 from .cutlass_int4_packed_layout import (
     CutlassInt4PackedLayout,
 )
@@ -39,7 +36,6 @@
 
 __all__ = [
     "UintxLayout",
-    "BlockSparseLayout",
     "MarlinSparseLayout",
     "SemiSparseLayout",
     "TensorCoreTiledLayout",
```

torchao/dtypes/uintx/block_sparse_layout.py
Lines changed: 15 additions & 224 deletions

```diff
@@ -3,231 +3,22 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-import logging
-from dataclasses import dataclass
-from typing import Optional, Tuple
 
-import torch
-from torch.utils._python_dispatch import (
-    return_and_correct_aliasing,
-)
+# Backward compatibility stub - imports from the new location
+import warnings
 
-from torchao.dtypes.affine_quantized_tensor import (
-    AffineQuantizedTensor,
-    register_layout,
-)
-from torchao.dtypes.uintx.plain_layout import (
-    PlainAQTTensorImpl,
-    _aqt_is_int8_reduced_range,
-)
-from torchao.dtypes.utils import (
-    Layout,
-    PlainLayout,
+warnings.warn(
+    "Importing BlockSparseLayout from torchao.dtypes is deprecated. "
+    "Please use 'from torchao.prototype.dtypes import BlockSparseLayout' instead. "
+    "This import path will be removed in a future torchao release. "
+    "Please check issue: https://github.com/pytorch/ao/issues/2752 for more details. ",
+    DeprecationWarning,
+    stacklevel=2,
 )
 
-logger = logging.getLogger(__name__)
-
-aten = torch.ops.aten
-
-
-@dataclass(frozen=True)
-class BlockSparseLayout(Layout):
-    """BlockSparseLayout is a data class that represents the layout of a block sparse matrix.
-
-    Attributes:
-        blocksize (int): The size of the blocks in the sparse matrix. Default is 64.
-    """
-
-    blocksize: int = 64
-
-
-@register_layout(BlockSparseLayout)
-class BlockSparseAQTTensorImpl(PlainAQTTensorImpl):
-    bsr_crow_indices: Optional[torch.Tensor]
-    bsr_col_indices: Optional[torch.Tensor]
-    bsr_values: Optional[torch.Tensor]
-    scale: Optional[torch.Tensor]
-    zero_point: Optional[torch.Tensor]
-
-    __slots__ = [
-        "bsr_crow_indices",
-        "bsr_col_indices",
-        "bsr_values",
-        "scale",
-        "zero_point",
-    ]
-
-    @staticmethod
-    def __new__(  # noqa: PYI034
-        cls,
-        shape: torch.Size,
-        bsr_crow_indices: Optional[torch.Tensor],
-        bsr_col_indices: Optional[torch.Tensor],
-        bsr_values: Optional[torch.Tensor],
-        scale: Optional[torch.Tensor],
-        zero_point: Optional[torch.Tensor],
-        _layout: Layout,
-        requires_grad: bool = False,
-    ):
-        if bsr_values is None:
-            raise ValueError("bsr values must be provided!")
-        else:
-            previous_tensor = bsr_values
-
-        kwargs = {
-            "device": previous_tensor.device,
-            "dtype": previous_tensor.dtype,
-            "layout": previous_tensor.layout,
-            "requires_grad": requires_grad,
-        }
-        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
-
-    def __init__(  # noqa: PYI034
-        self,
-        shape: torch.Size,
-        bsr_crow_indices: Optional[torch.Tensor],
-        bsr_col_indices: Optional[torch.Tensor],
-        bsr_values: Optional[torch.Tensor],
-        scale: Optional[torch.Tensor],
-        zero_point: Optional[torch.Tensor],
-        _layout: Layout,
-        requires_grad: bool = False,
-    ):
-        self.bsr_crow_indices = bsr_crow_indices
-        self.bsr_col_indices = bsr_col_indices
-        self.bsr_values = bsr_values
-        self.scale = scale
-        self.zero_point = zero_point
-        self._layout = _layout
-
-    def __tensor_flatten__(self):
-        inner_tensors = list(
-            filter(lambda x: getattr(self, x) is not None, self.__slots__)
-        )
-        tensor_meta = (self.shape, self._layout, self.requires_grad)
-        return inner_tensors, tensor_meta
-
-    @classmethod
-    def __tensor_unflatten__(
-        cls,
-        inner_tensors,
-        tensor_meta: Tuple[torch.Size, bool],
-        outer_size,
-        outer_stride,
-    ) -> torch.Tensor:
-        shape, _layout, requires_grad = tensor_meta
-        return cls(
-            shape=shape,
-            bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
-            bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
-            bsr_values=inner_tensors.get("bsr_values", None),
-            scale=inner_tensors.get("scale", None),
-            zero_point=inner_tensors.get("zero_point", None),
-            _layout=_layout,
-            requires_grad=requires_grad,
-        )
-
-    @classmethod
-    def from_plain(cls, int_data, scale, zero_point, _layout):
-        bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize)
-        return cls(
-            shape=int_data.shape,
-            bsr_crow_indices=bsr_tensor.crow_indices(),
-            bsr_col_indices=bsr_tensor.col_indices(),
-            bsr_values=bsr_tensor.values(),
-            scale=scale,
-            zero_point=zero_point,
-            _layout=_layout,
-            requires_grad=False,
-        )
-
-    def get_plain(self):
-        int_data_expanded = torch.ops.blocksparse.bsr_to_dense(
-            self.crow_indices(),
-            self.col_indices(),
-            self.values(),
-            self.shape[0],
-            self.shape[1],
-        )
-        return int_data_expanded, self.scale, self.zero_point
-
-    def _apply_fn_to_data(self, func):
-        return self.__class__(
-            shape=self.shape,
-            bsr_crow_indices=func(self.bsr_crow_indices),
-            bsr_col_indices=func(self.bsr_col_indices),
-            bsr_values=func(self.bsr_values),
-            scale=self.scale,
-            zero_point=self.zero_point,
-            _layout=self._layout,
-            requires_grad=self.requires_grad,
-        )
-
-    @classmethod
-    def __torch_dispatch__(cls, func, types, args, kwargs):
-        kwargs = {} if kwargs is None else kwargs
-
-        if func is aten.detach.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
-            )
-        if func is aten.clone.default:
-            return return_and_correct_aliasing(
-                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
-            )
-
-        # Need the following for bsr specific functions
-        if func is aten.crow_indices.default:
-            return args[0].bsr_crow_indices.detach()
-
-        if func is aten.col_indices.default:
-            return args[0].bsr_col_indices.detach()
-
-        if func is aten.values.default:
-            return args[0].bsr_values.detach()
-
-        if func is aten._nnz.default:
-            return args[0].bsr_values.shape[0]
-
-        raise NotImplementedError(
-            f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
-        )
-
-
-def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias):
-    return (
-        isinstance(input_tensor, AffineQuantizedTensor)
-        and _aqt_is_int8_reduced_range(input_tensor)
-        and isinstance(weight_tensor, AffineQuantizedTensor)
-        and weight_tensor.is_cuda
-        and input_tensor.dtype == weight_tensor.dtype
-        and isinstance(input_tensor._layout, PlainLayout)
-        and isinstance(weight_tensor._layout, BlockSparseLayout)
-    )
-
-
-def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias):
-    x_vals_int8 = input_tensor.tensor_impl.int_data
-    x_scales = input_tensor.tensor_impl.scale
-    w_vals = weight_tensor.tensor_impl
-    w_scales = weight_tensor.tensor_impl.scale
-    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-    tmp_t = tmp.t()
-
-    y = torch.ops.blocksparse.int_addmm(
-        w_vals.crow_indices(),
-        w_vals.col_indices(),
-        w_vals.values(),
-        tmp_t,
-        w_scales,
-        x_scales.reshape(-1),
-    )
-    y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
-    y = y.reshape(*y_shape)
-
-    # can downcast only at the very end
-    output_dtype = input_tensor.dtype
-    y = y.to(output_dtype)
-    if bias is not None:
-        y += bias
-    return y
+from torchao.prototype.dtypes.uintx.block_sparse_layout import (
+    BlockSparseAQTTensorImpl,  # noqa: F401
+    BlockSparseLayout,  # noqa: F401
+    _linear_int8_act_int8_weight_block_sparse_check,  # noqa: F401
+    _linear_int8_act_int8_weight_block_sparse_impl,  # noqa: F401
+)
```
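
Because the warning is raised at module scope, it fires only when the stub module body executes, i.e. on the first import in a process; that is why the test above has to purge `sys.modules` before re-importing. A minimal sketch of that behavior (fresh interpreter assumed):

```python
import warnings

with warnings.catch_warnings(record=True) as first:
    warnings.simplefilter("always")
    from torchao.dtypes import BlockSparseLayout  # stub body runs: warning recorded

with warnings.catch_warnings(record=True) as second:
    warnings.simplefilter("always")
    from torchao.dtypes import BlockSparseLayout  # noqa: F811 - cached, body not re-run

assert any(issubclass(w.category, DeprecationWarning) for w in first)
assert not second  # module code only executes once per process
```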
torchao/prototype/dtypes/__init__.py
Lines changed: 11 additions & 0 deletions

```diff
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .uintx import BlockSparseLayout
+
+__all__ = [
+    "BlockSparseLayout",
+]
```
torchao/prototype/dtypes/uintx/__init__.py
Lines changed: 11 additions & 0 deletions

```diff
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .block_sparse_layout import BlockSparseLayout
+
+__all__ = [
+    "BlockSparseLayout",
+]
```
