File tree Expand file tree Collapse file tree 1 file changed +56
-0
lines changed
torchao/quantization/quantize_/workflows/float8 Expand file tree Collapse file tree 1 file changed +56
-0
lines changed Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD 3-Clause license found in the
5+ # LICENSE file in the root directory of this source tree.
6+ import torch
7+ from torchao .utils import TorchAOBaseTensor
8+
9+ __all__ = ["CutlassSemiSparseFp8Tensor" ]
10+ aten = torch .ops .aten
11+
12+ class CutlassSemiSparseFp8Tensor (TorchAOBaseTensor ):
13+ tensor_data_names = ["sparse" , "scale" , "meta" ]
14+
15+ def __new__ (
16+ cls ,
17+ sparse : torch .Tensor ,
18+ meta : torch .Tensor ,
19+ scale : torch .Tensor ,
20+ ):
21+ kwargs = {}
22+ kwargs ["device" ] = sparse .device
23+ kwargs ["dtype" ] = scale .dtype
24+ kwargs ["requires_grad" ] = False
25+ shape = (sparse .shape [0 ], 2 * sparse .shape [- 1 ])
26+ return torch .Tensor ._make_wrapper_subclass (cls , shape , ** kwargs ) # type: ignore[attr-defined]
27+
28+
29+ def __init__ (
30+ self ,
31+ sparse : torch .Tensor ,
32+ meta : torch .Tensor ,
33+ scale : torch .Tensor ,
34+ ):
35+ super ().__init__ ()
36+ self .sparse = sparse
37+ self .meta = meta
38+ self .scale = scale
39+
40+ def _quantization_type (self ):
41+ return f"shape={ self .shape } , device={ self .device } , dtype={ self .dtype } "
42+
43+
44+ @classmethod
45+ def from_hp (
46+ ):
47+ raise NotImplementedError ("CutlassSemiSparseFp8Tensor.from_hp is not implemented yet" )
48+
49+
50+ implements = CutlassSemiSparseFp8Tensor .implements
51+ implements_torch_function = CutlassSemiSparseFp8Tensor .implements_torch_function
52+
53+ CutlassSemiSparseFp8Tensor .__module__ = "torchao.quantization"
54+
55+ # Allow a model with CutlassSemiSparseFp8Tensor weights to be loaded with `weights_only=True`
56+ torch .serialization .add_safe_globals ([CutlassSemiSparseFp8Tensor ])
You can’t perform that action at this time.
0 commit comments