-
Notifications
You must be signed in to change notification settings - Fork 365
introduce new int8 quantization API #3241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
da7cfea
27076d3
cdb1d9f
9f1b6c9
9301717
caaba7a
0f51ee6
305c3a9
3ab38ba
b516304
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,222 @@ | ||||||||||||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||||||||
| # All rights reserved. | ||||||||||||
| # | ||||||||||||
| # This source code is licensed under the BSD 3-Clause license found in the | ||||||||||||
| # LICENSE file in the root directory of this source tree. | ||||||||||||
|
|
||||||||||||
| import copy | ||||||||||||
| import unittest | ||||||||||||
|
|
||||||||||||
| import torch | ||||||||||||
| from torch._inductor.utils import run_and_get_code | ||||||||||||
| from torch.testing import FileCheck | ||||||||||||
| from torch.testing._internal import common_utils | ||||||||||||
|
|
||||||||||||
| from torchao.quantization import ( | ||||||||||||
| Int8DynamicActivationInt8WeightConfig, | ||||||||||||
| Int8WeightOnlyConfig, | ||||||||||||
| quantize_, | ||||||||||||
| ) | ||||||||||||
| from torchao.quantization.utils import compute_error | ||||||||||||
| from torchao.testing.utils import TorchAOIntegrationTestCase | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| # TODO: Refactor after https://github.com/pytorch/ao/pull/2729 is merged | ||||||||||||
| class ToyTwoLinearModel(torch.nn.Module): | ||||||||||||
| def __init__( | ||||||||||||
| self, | ||||||||||||
| input_dim, | ||||||||||||
| hidden_dim, | ||||||||||||
| output_dim, | ||||||||||||
| has_bias=False, | ||||||||||||
| dtype=None, | ||||||||||||
| device=None, | ||||||||||||
| ): | ||||||||||||
| super().__init__() | ||||||||||||
| self.dtype = dtype | ||||||||||||
| self.device = device | ||||||||||||
| self.linear1 = torch.nn.Linear( | ||||||||||||
| input_dim, hidden_dim, bias=has_bias, dtype=dtype, device=device | ||||||||||||
| ) | ||||||||||||
| self.linear2 = torch.nn.Linear( | ||||||||||||
| hidden_dim, output_dim, bias=has_bias, dtype=dtype, device=device | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| def forward(self, x): | ||||||||||||
| x = self.linear1(x) | ||||||||||||
| x = self.linear2(x) | ||||||||||||
| return x | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") | ||||||||||||
| @common_utils.instantiate_parametrized_tests | ||||||||||||
| class TestInt8Tensor(TorchAOIntegrationTestCase): | ||||||||||||
| def setUp(self): | ||||||||||||
| super().setUp() | ||||||||||||
|
|
||||||||||||
| self.test_shape = (32, 20) | ||||||||||||
| self.dtype = torch.bfloat16 | ||||||||||||
| self.batch_size = 32 | ||||||||||||
|
|
||||||||||||
| torch.manual_seed(42) | ||||||||||||
|
|
||||||||||||
| @common_utils.parametrize( | ||||||||||||
| "config", | ||||||||||||
| [ | ||||||||||||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||||||||||||
| Int8WeightOnlyConfig(version=2), | ||||||||||||
| ], | ||||||||||||
| ) | ||||||||||||
| def test_creation_and_attributes(self, config): | ||||||||||||
| """Test tensor creation, dtypes, and ranges""" | ||||||||||||
| linear = torch.nn.Linear( | ||||||||||||
| self.test_shape[1], | ||||||||||||
| self.test_shape[0], | ||||||||||||
| bias=False, | ||||||||||||
| dtype=self.dtype, | ||||||||||||
| device="cuda", | ||||||||||||
| ) | ||||||||||||
| quantize_(linear, config) | ||||||||||||
|
|
||||||||||||
| w = linear.weight | ||||||||||||
|
|
||||||||||||
| self.assertEqual(w.shape, self.test_shape) | ||||||||||||
| self.assertEqual(w.qdata.dtype, torch.int8) | ||||||||||||
| self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127)) | ||||||||||||
|
|
||||||||||||
| @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) | ||||||||||||
| @common_utils.parametrize("compile", [True, False]) | ||||||||||||
| @common_utils.parametrize( | ||||||||||||
| "config", | ||||||||||||
| [ | ||||||||||||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||||||||||||
| Int8WeightOnlyConfig(version=2), | ||||||||||||
| ], | ||||||||||||
| ) | ||||||||||||
| @common_utils.parametrize( | ||||||||||||
| "sizes", | ||||||||||||
| [ | ||||||||||||
| ((128,), 256, 128), # 2D | ||||||||||||
| ((32, 128), 64, 256), # 3D | ||||||||||||
| ], | ||||||||||||
| ) | ||||||||||||
| def test_int8_linear_variants( | ||||||||||||
| self, | ||||||||||||
| dtype: torch.dtype, | ||||||||||||
| config, | ||||||||||||
| compile: bool, | ||||||||||||
| sizes: tuple, | ||||||||||||
| ): | ||||||||||||
| """Test linear operation supports including shape and compile""" | ||||||||||||
| M, N, K = sizes | ||||||||||||
| input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") | ||||||||||||
| model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() | ||||||||||||
| model_q = copy.deepcopy(model) | ||||||||||||
|
|
||||||||||||
| quantize_(model_q, config) | ||||||||||||
|
|
||||||||||||
| self.assertEqual(model_q.linear2.weight.scale.shape, (K,)) | ||||||||||||
| self.assertEqual(model_q.linear2.weight.scale.ndim, 1) | ||||||||||||
|
|
||||||||||||
| if compile: | ||||||||||||
| model_q = torch.compile(model_q, fullgraph=True) | ||||||||||||
|
|
||||||||||||
| output_fp = model(input_tensor) | ||||||||||||
| output_quantized = model_q(input_tensor) | ||||||||||||
|
|
||||||||||||
| assert compute_error(output_fp, output_quantized) > 20, ( | ||||||||||||
| f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| @common_utils.parametrize( | ||||||||||||
| "config", | ||||||||||||
| [ | ||||||||||||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||||||||||||
| Int8WeightOnlyConfig(version=2), | ||||||||||||
| ], | ||||||||||||
| ) | ||||||||||||
| @common_utils.parametrize("device", ["cpu", "cuda"]) | ||||||||||||
| @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||||||||||||
| def test_slice(self, config, device, dtype): | ||||||||||||
| """Test tensor slicing with per-row quantization""" | ||||||||||||
| tensor_size = 256 | ||||||||||||
| slice_sizes = (64, 128) | ||||||||||||
|
|
||||||||||||
| dummy = torch.nn.Linear( | ||||||||||||
| tensor_size, tensor_size, bias=False, dtype=dtype, device=device | ||||||||||||
| ) | ||||||||||||
| quantize_(dummy, config) | ||||||||||||
|
|
||||||||||||
| weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0]) | ||||||||||||
| weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1]) | ||||||||||||
|
|
||||||||||||
| self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0])) | ||||||||||||
| self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1])) | ||||||||||||
| self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])) | ||||||||||||
| self.assertEqual(weight2.scale, dummy.weight.scale) | ||||||||||||
| with self.assertRaises(NotImplementedError): | ||||||||||||
| _ = dummy.weight[::2] | ||||||||||||
|
|
||||||||||||
| @common_utils.parametrize( | ||||||||||||
| "config", | ||||||||||||
| [ | ||||||||||||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||||||||||||
| Int8WeightOnlyConfig(version=2), | ||||||||||||
| ], | ||||||||||||
| ) | ||||||||||||
| def test_index_select(self, config): | ||||||||||||
| """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor.""" | ||||||||||||
| N, K = 256, 512 | ||||||||||||
| x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) | ||||||||||||
| linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") | ||||||||||||
| linear.weight.data = x | ||||||||||||
| quantize_(linear, config) | ||||||||||||
|
|
||||||||||||
| x_int8 = linear.weight | ||||||||||||
| x_int8_0 = x_int8[0] | ||||||||||||
| torch.testing.assert_close( | ||||||||||||
| x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| @common_utils.parametrize( | ||||||||||||
| "config", | ||||||||||||
| [ | ||||||||||||
| Int8DynamicActivationInt8WeightConfig(version=2), | ||||||||||||
| Int8WeightOnlyConfig(version=2), | ||||||||||||
| ], | ||||||||||||
| ) | ||||||||||||
| def test_dequantization_accuracy(self, config): | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. additional comments for this test, 1. increase the size of linear? 2. I think we don't need to overwrite the weight, we can just save the floating point weight (deepcopy) before quantization and compare the results |
||||||||||||
| """Test dequantization accuracy separately""" | ||||||||||||
| test_data = torch.tensor([[1.0, -1.0]], dtype=torch.bfloat16, device="cuda") | ||||||||||||
| linear = torch.nn.Linear(2, 1, bias=False, dtype=torch.bfloat16, device="cuda") | ||||||||||||
| linear.weight.data = test_data | ||||||||||||
| quantize_(linear, config) | ||||||||||||
|
|
||||||||||||
| tensor = linear.weight | ||||||||||||
| dequantized = tensor.dequantize() | ||||||||||||
| self.assertEqual(dequantized.shape, test_data.shape) | ||||||||||||
| assert compute_error(dequantized, test_data) > 20, ( | ||||||||||||
| f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}" | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| @common_utils.parametrize( | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this have to be parametrize? I think what we need here is to check the code contains a sequence of ops / kernel calls, like this: ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py Lines 827 to 831 in b4ec4cb
I think we can check 1. the quantize op and then 2. the mm op |
||||||||||||
| "kernel", | ||||||||||||
| ["triton_per_fused", "extern_kernels._int_mm", "triton_poi_fused"], | ||||||||||||
| ) | ||||||||||||
| def test_available_gpu_kernels(self, kernel): | ||||||||||||
| """Check which GPU kernels are available""" | ||||||||||||
| M, K, N = 128, 256, 512 | ||||||||||||
| m = torch.nn.Sequential( | ||||||||||||
| torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) | ||||||||||||
| ) | ||||||||||||
| config = Int8DynamicActivationInt8WeightConfig(version=2) | ||||||||||||
| quantize_(m, config) | ||||||||||||
| m = torch.compile(m) | ||||||||||||
| x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) | ||||||||||||
|
|
||||||||||||
| out, code = run_and_get_code(m, x) | ||||||||||||
| FileCheck().check(kernel).run(code[0]) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| if __name__ == "__main__": | ||||||||||||
| common_utils.run_tests() | ||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -140,7 +140,18 @@ def _slice_scale_for_dimension( | |||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||
| aten = torch.ops.aten | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Unsupported case for now, this would be 1 scale per data element | ||||||||||||||||||||||||||||||||||||||||||||
| # Per-tensor quantization (scalar scale) | ||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this change related?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is updated to support more granularity. Without this change, we can't use per-tensor (0D scale) and per-row (1D scale).
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So maybe it's better to move this util function to a common place?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this can be moved to torchao/quantization/quantize_/common/utils.py I think
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, then I will move this to |
||||||||||||||||||||||||||||||||||||||||||||
| if scale.numel() == 1: | ||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note: I think we can just check for ndim consistently everywhere, after #3324 is fixed
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also isn't handling for per tensor and per row already included in original code? ao/torchao/float8/inference.py Lines 158 to 178 in b4ec4cb
|
||||||||||||||||||||||||||||||||||||||||||||
| return scale | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Per-row quantization (1D scale) | ||||||||||||||||||||||||||||||||||||||||||||
| if scale.ndim == 1: | ||||||||||||||||||||||||||||||||||||||||||||
| if dim == 0: | ||||||||||||||||||||||||||||||||||||||||||||
| return aten.slice.Tensor(scale, 0, start, end, step) | ||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||
| return scale | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Block-wise quantization (2D scale) | ||||||||||||||||||||||||||||||||||||||||||||
| if scale.shape == data_shape: | ||||||||||||||||||||||||||||||||||||||||||||
| return aten.slice.Tensor(scale, dim, start, end, step) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -158,6 +169,12 @@ def _slice_scale_for_dimension( | |||||||||||||||||||||||||||||||||||||||||||
| # Slice away as normal | ||||||||||||||||||||||||||||||||||||||||||||
| return aten.slice.Tensor(scale, dim, start, end, step) | ||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||
| # Error on Step > 1 | ||||||||||||||||||||||||||||||||||||||||||||
| if step > 1: | ||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||||
| "Slicing with step > 1 is not implemented for scale tensors." | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # There is blocking in this dimension | ||||||||||||||||||||||||||||||||||||||||||||
| # Calculate which scale elements correspond to the sliced data | ||||||||||||||||||||||||||||||||||||||||||||
| scale_start = start // block_size_for_dim if start is not None else None | ||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -167,12 +184,6 @@ def _slice_scale_for_dimension( | |||||||||||||||||||||||||||||||||||||||||||
| else None | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| # Error on Step > 1 | ||||||||||||||||||||||||||||||||||||||||||||
| if step > 1: | ||||||||||||||||||||||||||||||||||||||||||||
| raise NotImplementedError( | ||||||||||||||||||||||||||||||||||||||||||||
| "Slicing with step > 1 is not implemented for scale tensors." | ||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
| return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1) | ||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: this can probably be moved to
TorchAOIntegrationTestCasecan also add 8e3b3da
feel free to do in a separate PR though