diff --git a/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
new file mode 100644
index 0000000000..232414bbb6
--- /dev/null
+++ b/test/quantization/quantize_/workflows/float8/test_float8_semi_sparse.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
+
+from torchao.float8.inference import Float8MMConfig
+from torchao.quantization.quantize_.workflows.float8.float8_semi_sparse_tensor import (
+    Float8SemiSparseTensor,
+)
+from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
+from torchao.sparsity.sparse_api import apply_fake_sparsity
+from torchao.testing.utils import skip_if_rocm
+from torchao.utils import is_sm_at_least_90
+
+
+@unittest.skipIf(not is_sm_at_least_90(), "Need H100+ to run")
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+class TestFloat8SemiSparseTensor(TestCase):
+    def setUp(self):
+        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
+
+    @skip_if_rocm("ROCm enablement in progress")
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 128),
+        ],
+    )
+    def test_sparse_vs_dense_fp8(self, sizes):
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+
+        apply_fake_sparsity(linear)
+
+        mm_config = Float8MMConfig(use_fast_accum=True)
+        input_fp8 = Float8Tensor.from_hp(
+            input, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
+        )
+
+        weight_fp8 = Float8Tensor.from_hp(
+            linear.weight.data, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config
+        )
+        dense_output = torch.nn.functional.linear(input_fp8, weight_fp8, linear.bias)
+
+        weight_sparse_fp8 = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])
+        sparse_output = torch.nn.functional.linear(
+            input_fp8, weight_sparse_fp8, linear.bias
+        )
+
+        torch.testing.assert_close(dense_output, sparse_output, atol=3e-1, rtol=3e-1)
+
+
+instantiate_parametrized_tests(TestFloat8SemiSparseTensor)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index aa19aa1890..b44bcb107c 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -78,6 +78,7 @@
     quantize_affine,
 )
 from .quantize_.workflows import (
+    Float8SemiSparseTensor,
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
@@ -148,6 +149,7 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
+    "Float8SemiSparseTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 4307637f8e..7166e244a6 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -1,3 +1,6 @@
+from .float8.float8_semi_sparse_tensor import (
+    Float8SemiSparseTensor,
+)
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -38,6 +41,7 @@
"Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", "Float8Tensor", + "Float8SemiSparseTensor", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", "Int4ChooseQParamsAlgorithm", diff --git a/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py new file mode 100644 index 0000000000..4384cc0aff --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_semi_sparse_tensor.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +import torch + +from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8 +from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, +) +from torchao.utils import TorchAOBaseTensor + +__all__ = ["Float8SemiSparseTensor"] +aten = torch.ops.aten + + +class Float8SemiSparseTensor(TorchAOBaseTensor): + tensor_data_names = ["sparse", "meta", "scale"] + + def __new__( + cls, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + ): + kwargs = {} + kwargs["device"] = sparse.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + shape = (sparse.shape[0], 2 * sparse.shape[-1]) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + sparse: torch.Tensor, + meta: torch.Tensor, + scale: torch.Tensor, + ): + super().__init__() + self.sparse = sparse + self.meta = meta + self.scale = scale + + def _quantization_type(self): + return f"shape={self.shape}, device={self.device}, dtype={self.dtype}" + + @classmethod + def from_hp( + cls, + w: torch.Tensor, + block_size: List[int], + ): + from torchao.sparsity.utils import mask_creator + + dense = w * mask_creator(w).bool() + + scale = _choose_scale_float8( + dense, + block_size=block_size, + float8_dtype=torch.float8_e4m3fn, + ) + + w_fp8 = _quantize_affine_float8( + dense, + scale=scale, + float8_dtype=torch.float8_e4m3fn, + ) + + sparse, meta = to_sparse_semi_structured_cutlass_sm9x_f8(w_fp8) + + return cls( + sparse, + meta, + scale, + ) + + +implements = Float8SemiSparseTensor.implements +implements_torch_function = Float8SemiSparseTensor.implements_torch_function + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + from torch.utils._python_dispatch import return_and_correct_aliasing + + self = args[0] + new = Float8SemiSparseTensor( + sparse=self.sparse, + meta=self.meta, + scale=self.scale, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + +def _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias): + from torchao.ops import rowwise_scaled_linear_sparse_cutlass_f8f8 + from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( + Float8Tensor, + ) + + if isinstance(input_tensor, Float8Tensor): + input = input_tensor.qdata + input_scale = input_tensor.scale + out_dtype = input_tensor.dtype + else: + input = input_tensor.qdata + input_scale = input_tensor.scale + out_dtype = input_tensor.dtype + + weight = weight_tensor.sparse + weight_meta = weight_tensor.meta + weight_scale = weight_tensor.scale + + # Reshape input_scale if needed: kernel expects scale to match input shape minus last dim + # For input [B, K], scale should be [B] not [B, 1] + if input_scale.dim() > input.dim() - 1: + input_scale = 
+
+    return rowwise_scaled_linear_sparse_cutlass_f8f8(
+        input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
+    )
+
+
+@implements([aten.mm.default, aten.addmm.default])
+def _(func, types, args, kwargs):
+    if func == aten.addmm.default:
+        bias, input_tensor, weight_tensor = args
+    else:  # aten.mm.default
+        input_tensor, weight_tensor = args
+        bias = None
+
+    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)
+
+
+@implements(aten.linear.default)
+@implements_torch_function(torch.nn.functional.linear)
+def _(func, types, args, kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    return _linear_fp8_semi_sparse(input_tensor, weight_tensor, bias)
+
+
+Float8SemiSparseTensor.__module__ = "torchao.quantization"
+
+# Allow a model with Float8SemiSparseTensor weights to be loaded with `weights_only=True`
+torch.serialization.add_safe_globals([Float8SemiSparseTensor])
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 47395a15af..97faa8ce06 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -256,9 +256,10 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-    assert isinstance(weight_tensor, Float8Tensor), (
-        f"Don't expect to reach here with an override other than weight currently, {type(input_tensor)} {type(weight_tensor)}"
-    )
+
+    # If the weight is not a Float8Tensor, return NotImplemented so the weight subclass's dispatch can handle the op
+    if not isinstance(weight_tensor, Float8Tensor):
+        return NotImplemented
     act_quant_kwargs = weight_tensor.act_quant_kwargs
     # quantizing activation, if `act_quant_kwargs` is specified
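
Usage sketch, distilled from the new test above (a minimal example, not part of the
patch; it assumes CUDA on an SM90+ GPU, since the packing op is CUTLASS SM9x-only):

    import torch
    import torch.nn.functional as F
    from torchao.float8.inference import Float8MMConfig
    from torchao.quantization import Float8SemiSparseTensor
    from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor
    from torchao.sparsity.sparse_api import apply_fake_sparsity

    K, N = 128, 256
    linear = torch.nn.Linear(K, N, dtype=torch.bfloat16, device="cuda")
    apply_fake_sparsity(linear)  # prune the weight into the 2:4 pattern

    x = torch.randn(32, K, dtype=torch.bfloat16, device="cuda")
    mm_config = Float8MMConfig(use_fast_accum=True)
    x_fp8 = Float8Tensor.from_hp(x, float8_dtype=torch.float8_e4m3fn, mm_config=mm_config)

    # block_size [1, K] yields one scale per output row of the weight
    w_fp8_sparse = Float8SemiSparseTensor.from_hp(linear.weight.data, [1, K])

    # Float8Tensor returns NotImplemented for the non-Float8Tensor weight,
    # so dispatch falls through to Float8SemiSparseTensor, which calls
    # rowwise_scaled_linear_sparse_cutlass_f8f8.
    y = F.linear(x_fp8, w_fp8_sparse, linear.bias)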