introduce new int8 quantization API #3241
base: main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241
    )

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_quantization_shapes(self, dtype):
This seems to combine two tests, one for dynamic quant and one for static quant. Can you use something like this:

    @common_utils.parametrize("mode", ["dynamic", "weight-only"])

Also, I feel it might be better not to add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR.
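A hedged sketch of the combined test suggested above, assuming it lives in the existing test class with the usual torchao test imports (quantize_, common_utils, and the version=2 configs discussed in this PR); shapes are illustrative.

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
    def test_quantization_shapes(self, dtype, mode):
        # pick the version=2 config matching the requested mode
        config = (
            Int8DynamicActivationInt8WeightConfig(version=2)
            if mode == "dynamic"
            else Int8WeightOnlyConfig(version=2)
        )
        linear = torch.nn.Linear(128, 256, dtype=dtype, device="cuda")
        quantize_(linear, config)
        x = torch.randn(32, 128, dtype=dtype, device="cuda")
        # output shape should be unchanged by quantization
        self.assertEqual(linear(x).shape, (32, 256))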
Okay, I wasn't sure whether to remove the static flags before (although they aren't fully implemented), but a smaller PR is always better, I feel. I will remove static_scale and all the related support.
    if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
        # INT8 × INT8 (static)
        scale = act_quant_kwargs.static_scale
        zero_point = torch.zeros_like(scale, dtype=torch.int8)
I think the user should specify static_zero_point as well.
But again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (there is no config).
    # Cast fp16 scale to float
    intermediate_dtype = (
        torch.float if x_scales.dtype == torch.half else x_scales.dtype
    )
    # Note: CUDA doesn't support int32/int64 matmul, so we convert to float
    # Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
    # This may introduce minor numerical differences compared to int arithmetic
    y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))

    # Apply activation scale
    is_per_tensor_act = x_scales.numel() == 1
    if is_per_tensor_act:
        y_dot.mul_(x_scales.to(intermediate_dtype))
    else:
        # For block-wise activation scale, reshape to match y_dot
        x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
        y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))

    # Apply weight scale
    is_per_tensor_weight = w_scales.numel() == 1
    if is_per_tensor_weight:
        result = y_dot.mul_(w_scales.to(intermediate_dtype))
    else:
        # Per-row weight scale - transpose and broadcast
        w_scales_broadcast = w_scales.t().expand_as(y_dot)
        result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))

    # Reshape back to original shape
    result = result.view(*x_vals.shape[:-1], result.shape[-1])
    result = result.to(activation_tensor.dtype)
This should follow:

    ao/torchao/dtypes/uintx/plain_layout.py, line 281 (e9c7bea):
    def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
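For reference, a minimal sketch of the per-row int8 × int8 linear pattern being pointed at, accumulating in int32 via torch._int_mm and rescaling afterwards; this is an illustrative sketch assuming per-row scales, not the actual plain_layout.py implementation (and _int_mm has CUDA-specific shape constraints).

    import torch

    def int8_linear_sketch(x_int8, x_scale, w_int8, w_scale, bias=None):
        # x_int8: (M, K) int8 activation, x_scale: (M, 1) float per-row scale
        # w_int8: (N, K) int8 weight,     w_scale: (N, 1) float per-row scale
        acc = torch._int_mm(x_int8, w_int8.t())            # (M, N) int32 accumulator
        y = acc.to(torch.float32) * x_scale * w_scale.t()  # rescale back to real values
        if bias is not None:
            y = y + bias
        return y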
jerryzh168 left a comment
I think we should
- split the static quant support into a separate PR
- follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py is doing for the quantized linear implementation
This should be a refactor PR, not a refactor plus some extra modifications plus some feature implementations, I think.
    aten = torch.ops.aten

    # Unsupported case for now, this would be 1 scale per data element
    # Per-tensor quantization (scalar scale)
is this change related?
It is updated to support more granularities. Without this change, we can't use per-tensor (0-D scale) or per-row (1-D scale).
So maybe it's better to move this util function to a common place?
this can be moved to torchao/quantization/quantize_/common/utils.py I think
Okay, then I will move this to torchao/quantization/quantize_/common/utils.py after this PR.
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    @common_utils.parametrize("has_bias", [True, False])
    def test_weight_only_linear_with_bias(self, dtype, has_bias):
this can probably be merged into the linear variants test as well
Thanks, I think the tensor changes look good, but we need a linear_variants test to make sure we cover different aspects of things (e.g. compile); see comments inline.
Can you also do an e2e perf check with https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py to make sure the performance is the same before and after the change for the ViT model?
Also, adding a kernel check might be useful to make sure we don't regress things:
    def test_expected_gpu_kernel_fbgemm(self):
Updated logs:
Hi @namgyu-youn Do you plan to submit another PR for static quantization? We also need static quantization for SmoothQuant, so we are wondering whether you have a plan or whether we should consider adding it ourselves. Thanks. CC @cyxlily
Yeah, static quantization support using static/dynamic flags is planned; I hope to show it to your team in the foreseeable future. Also, in the SmoothQuant case, validating its support for the new quantization APIs (below) has higher priority, I think. Could you look into it?
Thanks. Looking forward to it. If there is anything we can help with, please let us know.
By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant. |
Oh yes, it was a typo (W8A16 is right), and W4A16-INT (…). Because the current AWQ/SmoothQuant test only works with the old APIs (version 1), we can replace it with the new APIs like …
I see. Thanks. We will evaluate that.
Hi @namgyu-youn May I know if you have a timeline to land this? Thanks.
            ((32, 128), 64, 256),  # 3D
        ],
    )
    def test_int8_linear_quantization_accuracy(
can this one be combined with test_int8_linear_variants as well?
        with self.assertRaises(NotImplementedError):
            _ = dummy.weight[::2]

    def test_index_select(self):
I think all the tests should be modified to use config as well
        kernels = {}

        # Check for Triton kernels
        if "torch.ops.triton" in code[0]:
should add some asserts I think
- Configs are updated to global variants

Updated logs:
        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))

        # Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
        # Int8WeightOnlyConfig uses per-tensor (PerTensor)
It should be per-row, I think?

    ao/torchao/quantization/quant_api.py, line 1343 (6815e57):
    group_size = weight.shape[-1]
torchao/quantization/quant_api.py (outdated):

    )
    else:
        assert config.version == 2, f"Unexpected version: {config.version}"
        block_size = [weight.shape[0], weight.shape[1]]
why does this default to per tensor? I think it should follow the existing logic from L1376-1378?
        self.assertLess(
            torch.abs(dequantized - test_data).max().item(),
            0.1,
            msg=f"Dequantization error exceeds tolerance of {0.1}",
        )
Maybe check SQNR with compute_error instead of a hardcoded absolute value?
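A hedged sketch of what that could look like, using torchao's compute_error (SQNR) helper; the 20 dB threshold is an assumption, not a value taken from this PR.

    from torchao.quantization.utils import compute_error

    # higher SQNR means the dequantized tensor is closer to the reference
    sqnr = compute_error(dequantized, test_data)
    self.assertGreater(sqnr, 20.0, f"Dequantization SQNR too low: {sqnr} dB")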
        has_triton = "triton" in code[0].lower()  # Triton
        has_fbgemm = "fbgemm" in code[0].lower()  # FBGEMM
        has_int_mm = "_int_mm" in code[0]  # Int8 MatMul

        self.assertTrue(
            has_triton or has_fbgemm or has_int_mm,
            f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}",
        )
Is it possible to be more strict and spell out the kernels called and their order? Like:

    ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py, lines 596-600 (6815e57):
    FileCheck().check_count(
        "torch.ops.triton.quantize_fp8_row.default(", 1
    ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default(", 1).check_not(
        ".run("
    ).run(code[0])
    )

    x_vals = activation_tensor.qdata.reshape(-1, activation_tensor.qdata.shape[-1])
    x_scales = preprocess_scale(activation_tensor.scale, x_vals.shape)
Why is this called? Can you match the code with

    ao/torchao/dtypes/uintx/plain_layout.py, line 281 (6815e57):
    def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
Updated logs:
        self.dtype = torch.bfloat16
        self.batch_size = 32

        torch.manual_seed(42)
nit: this can probably be moved to TorchAOIntegrationTestCase
You can also add 8e3b3da; feel free to do it in a separate PR though.
            dtype=self.dtype,
            device="cuda",
        )
        linear.weight.data = self.weight_fp.cuda()
Why do we do this? It doesn't seem necessary.
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_per_row_scale_shape(self, dtype, config):
nit: can you merge the checks in this test into the previous test test_int8_linear_variants?
| f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}" | ||
| ) | ||
|
|
||
| @common_utils.parametrize( |
Does this have to be parametrized? I think what we need here is to check that the code contains a sequence of ops / kernel calls, like this:

    ao/test/quantization/quantize_/workflows/float8/test_float8_tensor.py, lines 827-831 (b4ec4cb):
    FileCheck().check_count(
        "torch.ops.triton.quantize_fp8_row.default(", 1
    ).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default(", 1).check_not(
        ".run("
    ).run(code[0])
I think we can check 1) the quantize op and then 2) the mm op extern_kernels._int_mm, in a single run (see the example above); that should be enough.
I see, thanks. We don't have to go with multiple runs.
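A hedged sketch of such a single-run check for the int8 path; only extern_kernels._int_mm is grounded in the discussion above, while the model setup and the exactly-once count are assumptions to be adjusted against the real inductor output.

    import torch
    from torch._inductor.utils import run_and_get_code
    from torch.testing import FileCheck
    from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_

    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, Int8DynamicActivationInt8WeightConfig(version=2))
    x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")

    # compile, capture the generated code, and assert the int8 matmul kernel shows up once
    _, code = run_and_get_code(torch.compile(linear), x)
    FileCheck().check_count("extern_kernels._int_mm", 1, exactly=True).run(code[0])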
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_dequantization_accuracy(self, config):
Additional comments for this test: 1) increase the size of the linear? 2) I think we don't need to overwrite the weight; we can just save the floating-point weight (deepcopy) before quantization and compare the results.
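A hedged sketch of point 2, assuming Int8Tensor exposes a dequantize() method (as other torchao tensor subclasses do); sizes and the SQNR threshold are illustrative.

    import copy
    import torch
    from torchao.quantization import Int8WeightOnlyConfig, quantize_
    from torchao.quantization.utils import compute_error

    linear = torch.nn.Linear(512, 1024, dtype=torch.bfloat16, device="cuda")
    weight_fp = copy.deepcopy(linear.weight)           # keep the float weight before quantization
    quantize_(linear, Int8WeightOnlyConfig(version=2))
    dequantized = linear.weight.dequantize()           # assumed Int8Tensor API
    assert compute_error(dequantized, weight_fp) > 20  # SQNR threshold is illustrative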
        self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
        self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
        self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
        self.block_size = list(self.test_shape)
I feel we probably don't need these; it's also easier for people to follow if everything (or most things) is defined in the test itself.
    # Unsupported case for now, this would be 1 scale per data element
    # Per-tensor quantization (scalar scale)
    if scale.numel() == 1:
note: I think we can just check for ndim consistently everywhere, after #3324 is fixed
Also, isn't handling for per-tensor and per-row already included in the original code?
    ao/torchao/float8/inference.py, lines 158-178 (b4ec4cb):
    if block_size_for_dim == 1:
        # Scale is per-element along this dimension
        # Slice away as normal
        return aten.slice.Tensor(scale, dim, start, end, step)
    else:
        # There is blocking in this dimension
        # Calculate which scale elements correspond to the sliced data
        scale_start = start // block_size_for_dim if start is not None else None
        scale_end = (
            (end + block_size_for_dim - 1) // block_size_for_dim
            if end is not None
            else None
        )
        # Error on Step > 1
        if step > 1:
            raise NotImplementedError(
                "Slicing with step > 1 is not implemented for scale tensors."
            )
        return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
    if len(act_kwargs.block_size) != input_ndim:
        if input_ndim == 3 and len(act_kwargs.block_size) == 2:
            block_size_updated = [1] + list(act_kwargs.block_size)
        else:
            block_size_updated = list(act_kwargs.block_size)[-input_ndim:]
We kind of changed the meaning of block_size used in PerBlock quant recently, check:

    ao/torchao/quantization/utils.py, line 707 (b4ec4cb):
    elif isinstance(granularity, PerBlock):
when is this code needed?
In principle we shouldn't update block_size here; instead, we should make sure block_size makes sense and is consistent throughout the codebase.
    quantized_weight = Int8Tensor.from_hp(
        weight,
        block_size,
        act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=block_size),
why does activation use the same block_size as weight?
I think we can change the block_size argument to granularity, since block_size is unknown for the activation (the shape is unknown).
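A hedged sketch of that suggestion; whether QuantizeTensorToInt8Kwargs accepts a granularity argument is an assumption about the new API, not something settled in this PR.

    from torchao.quantization.granularity import PerRow

    quantized_weight = Int8Tensor.from_hp(
        weight,
        block_size,
        # describe how to quantize the activation by granularity instead of
        # reusing the weight's block_size, since the activation shape is unknown here
        act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=PerRow()),
    )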
Summary:
Introduce a new tensor subclass API. Main features:
- Int8Tensor: the main API, which handles the quantization and dequantization operations.
- The API is integrated into the global config variants (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) via version, and is not enabled as a default. A minimal usage sketch is shown below.
Related Issue/PR:
This is a reopened PR for #3038.
Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py
Performance:
The following are the results of https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py with a batch size of 32 (torch.compile):