Skip to content

Commit 9b2cfe8

Browse files
bbeckca and facebook-github-bot
authored and committed
Add CutlassSemiSparseFp8Tensor
Summary: Moving float8 cutlass sparse layout into its own class: https://github.com/pytorch/ao/blob/main/torchao/dtypes/floatx/cutlass_semi_sparse_layout.py Differential Revision: D84467190
1 parent 30082cb commit 9b2cfe8

File tree

1 file changed

+56
-0
lines changed

1 file changed

+56
-0
lines changed
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import torch
7+
from torchao.utils import TorchAOBaseTensor
8+
9+
__all__ = ["CutlassSemiSparseFp8Tensor"]
10+
aten = torch.ops.aten
11+
12+
class CutlassSemiSparseFp8Tensor(TorchAOBaseTensor):
    """Tensor subclass wrapping a CUTLASS semi-structured (2:4) sparse fp8 weight.

    Holds three component tensors:
        sparse: packed non-zero values (half the logical columns).
        meta:   sparsity metadata describing which positions are kept.
        scale:  quantization scale; its dtype is presented as the tensor's
                outward-facing dtype.
    """

    # NOTE(review): this order ("sparse", "scale", "meta") differs from the
    # constructor's positional order (sparse, meta, scale). If TorchAOBaseTensor
    # reconstructs instances by passing tensors positionally in
    # tensor_data_names order, meta and scale would be swapped on
    # flatten/unflatten — confirm against TorchAOBaseTensor's contract.
    tensor_data_names = ["sparse", "scale", "meta"]

    def __new__(
        cls,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        kwargs = {}
        kwargs["device"] = sparse.device
        # Outward-facing dtype mirrors the scale's dtype (the original
        # high-precision dtype), not the fp8 storage dtype of `sparse`.
        kwargs["dtype"] = scale.dtype
        kwargs["requires_grad"] = False
        # 2:4 semi-structured packing stores half the columns, so the logical
        # width is twice the packed width.
        shape = (sparse.shape[0], 2 * sparse.shape[-1])
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        sparse: torch.Tensor,
        meta: torch.Tensor,
        scale: torch.Tensor,
    ):
        super().__init__()
        self.sparse = sparse
        self.meta = meta
        self.scale = scale

    def _quantization_type(self):
        """Return a short human-readable summary for repr/debug output."""
        return f"shape={self.shape}, device={self.device}, dtype={self.dtype}"

    @classmethod
    def from_hp(cls):
        """Construct from a high-precision tensor (not yet implemented).

        Fix: the original declared ``def from_hp():`` with no ``cls``
        parameter; since ``@classmethod`` passes the class implicitly,
        calling ``CutlassSemiSparseFp8Tensor.from_hp()`` raised
        ``TypeError`` instead of the intended ``NotImplementedError``.
        """
        raise NotImplementedError(
            "CutlassSemiSparseFp8Tensor.from_hp is not implemented yet"
        )
# Convenience aliases so ops can be registered on this subclass with the
# decorators inherited from TorchAOBaseTensor.
implements = CutlassSemiSparseFp8Tensor.implements
implements_torch_function = CutlassSemiSparseFp8Tensor.implements_torch_function

# Present the class as part of the public torchao.quantization namespace
# (affects repr/pickling paths that record __module__).
CutlassSemiSparseFp8Tensor.__module__ = "torchao.quantization"

# Allow a model with CutlassSemiSparseFp8Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([CutlassSemiSparseFp8Tensor])

0 commit comments

Comments
 (0)