
Conversation

Contributor

@namgyu-youn namgyu-youn commented Oct 24, 2025

Summary:
Introduce a new tensor subclass API. The main features are:

  • Int8Tensor: the main API, which handles quantization and dequantization operations
  • Utility operation functions: tensor slicing and index selection

This API is integrated into the existing config variants (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) behind the version flag; it is not enabled as the default.

Related Issue/PR:
This is a reopened version of PR #3038.

Test plan:
test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Performance:
The following are the results of https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py with a batch size of 32:

API   With torch.compile   Without torch.compile
Old   65.47 ms             234.39 ms
New   63.30 ms             239.30 ms


pytorch-bot bot commented Oct 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3241

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label on Oct 24, 2025
)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_quantization_shapes(self, dtype):
Contributor

@jerryzh168 jerryzh168 Oct 24, 2025


this seems to be a combination of two tests, one for dynamic quant and one for static quant; can you use something like this:

@common_utils.parametrize("mode", ["dynamic", "weight-only"])

also, I feel it might be better not to add static quant in this PR, and instead add both the tensor support and the config support for static quant in a separate PR
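
A minimal sketch of the parametrized consolidation suggested above, assuming the version-2 configs from this PR; the linear shape and the assertion body are illustrative only:

import torch
from torch.testing._internal import common_utils
from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)

class TestInt8Tensor(common_utils.TestCase):
    @common_utils.parametrize("mode", ["dynamic", "weight-only"])
    def test_quantization_shapes(self, mode):
        # one test body, config chosen by the parametrized mode
        config = (
            Int8DynamicActivationInt8WeightConfig(version=2)
            if mode == "dynamic"
            else Int8WeightOnlyConfig(version=2)
        )
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
        quantize_(linear, config)
        x = torch.randn(32, 128, dtype=torch.bfloat16, device="cuda")
        self.assertEqual(linear(x).shape, (32, 256))

common_utils.instantiate_parametrized_tests(TestInt8Tensor)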

Contributor Author


Okay, I wasn't sure about removing the static flags before (although they're not fully implemented), but I feel a smaller PR is always better. I will remove static_scale and all the related support.

if act_quant_kwargs is not None and act_quant_kwargs.static_scale is not None:
# INT8 × INT8 (static)
scale = act_quant_kwargs.static_scale
zero_point = torch.zeros_like(scale, dtype=torch.int8)
Contributor

@jerryzh168 jerryzh168 Oct 24, 2025


I think the user should specify static_zero_point as well

but again, it's better to do this in a separate PR, since the current state is only half of the static quant feature (there is no config for it)

Comment on lines 196 to 225
# Cast fp16 scale to float
intermediate_dtype = (
torch.float if x_scales.dtype == torch.half else x_scales.dtype
)
# Note: CUDA doesn't support int32/int64 matmul, so we convert to float
# Error message is NotImplementedError: "addmm_cuda" not implemented for 'Int'
# This may introduce minor numerical differences compared to int arithmetic
y_dot = torch.mm(tmp.to(intermediate_dtype), w_vals_t.to(intermediate_dtype))

# Apply activation scale
is_per_tensor_act = x_scales.numel() == 1
if is_per_tensor_act:
y_dot.mul_(x_scales.to(intermediate_dtype))
else:
# For block-wise activation scale, reshape to match y_dot
x_scales_reshaped = x_scales.view(y_dot.shape[0], -1)
y_dot.mul_(x_scales_reshaped.to(intermediate_dtype))

# Apply weight scale
is_per_tensor_weight = w_scales.numel() == 1
if is_per_tensor_weight:
result = y_dot.mul_(w_scales.to(intermediate_dtype))
else:
# Per-row weight scale - transpose and broadcast
w_scales_broadcast = w_scales.t().expand_as(y_dot)
result = y_dot.mul_(w_scales_broadcast.to(intermediate_dtype))

# Reshape back to original shape
result = result.view(*x_vals.shape[:-1], result.shape[-1])
result = result.to(activation_tensor.dtype)
Contributor


this should follow:

def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
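
For reference, a rough sketch of the int8 x int8 structure that implementation follows (per-row activation scale, per-output-channel weight scale); this is an illustration of the pattern, not the exact plain_layout.py code:

import torch

def int8_int8_linear_sketch(x_vals, x_scales, w_vals_t, w_scales, bias, out_dtype):
    # x_vals: [M, K] int8 activation, w_vals_t: [K, N] int8 weight (already transposed)
    y = torch._int_mm(x_vals, w_vals_t).to(torch.float32)  # int32 accumulation, then float
    y = y * x_scales.reshape(-1, 1).to(torch.float32)      # per-row activation scale
    y = y * w_scales.reshape(1, -1).to(torch.float32)      # per-output-channel weight scale
    if bias is not None:
        y = y + bias
    return y.to(out_dtype)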

Contributor

@jerryzh168 jerryzh168 left a comment


I think we should

  1. split the static quant support into a separate PR
  2. follow what https://github.com/pytorch/ao/blob/main/torchao/dtypes/uintx/plain_layout.py is doing for the quantized linear implementation

this should be a refactor PR, not a refactor plus extra modifications plus feature implementations, I think

aten = torch.ops.aten

# Unsupported case for now, this would be 1 scale per data element
# Per-tensor quantization (scalar scale)
Contributor


is this change related?

Contributor Author

@namgyu-youn namgyu-youn Oct 31, 2025


It is updated to support more granularities. Without this change, we can't use per-tensor (0D scale) or per-row (1D scale).

Collaborator


So maybe it's better to move this util function to a common place?

Contributor


this can be moved to torchao/quantization/quantize_/common/utils.py I think

Contributor Author


Okay, then I will move this to torchao/quantization/quantize_/common/utils.py after this PR.


@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
@common_utils.parametrize("has_bias", [True, False])
def test_weight_only_linear_with_bias(self, dtype, has_bias):
Contributor


this can probably be merged into the linear variants test as well

Contributor

@jerryzh168 jerryzh168 left a comment


thanks, I think the tensor changes look good, but we need a linear_variants test to make sure we cover different aspects of things (e.g. compile), see comments inline

can you also do an e2e perf check with https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py to make sure the performance is the same before and after the change for the ViT model?

also, adding a kernel check might be useful to make sure we don't regress things:

def test_expected_gpu_kernel_fbgemm(self):

Contributor Author

namgyu-youn commented Oct 31, 2025

Updated logs:

Collaborator

Xia-Weiwen commented Nov 3, 2025

Hi @namgyu-youn Do you plan to submit another PR for static quantization? We also need static quantization for SmoothQuant. So, we are wondering if you have a plan or we should consider adding it ourselves. Thanks. CC @cyxlily

@namgyu-youn
Contributor Author

> Hi @namgyu-youn Do you plan to submit another PR for static quantization? We also need static quantization for SmoothQuant. So, we are wondering if you have a plan or we should consider adding it ourselves. Thanks. CC @cyxlily

Yeah, static quantization support using static/dynamic flags is planned; I hope to show it to your team in the foreseeable future.

Also, in the SmoothQuant case, validating its support for the new quantization APIs (below) has higher priority, I think. Could you look into it?

  • W4A16-INT: Int4WeightOnlyConfig(group_size=32, version=2)
  • W4A16-FP: Float8WeightOnlyConfig(version=2)
  • W8A8-FP-dynamic: Float8DynamicActivationFloat8WeightConfig(version=2)
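
For reference, applying one of these version-2 configs with torchao's quantize_ looks roughly like this (arbitrary model and shapes, for illustration only):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(torch.bfloat16).cuda()
quantize_(model, Int4WeightOnlyConfig(group_size=32, version=2))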

@Xia-Weiwen
Collaborator

> Yeah, static quantization support using static/dynamic flags is planned; I hope to show it to your team in the foreseeable future.

Thanks. Looking forward to it. If there is anything we can help with, please let us know.

> Also, in the SmoothQuant case, validating its support for the new quantization APIs (below) has higher priority, I think. Could you look into it?
>
>   • W4A16-INT: Int4WeightOnlyConfig(group_size=32, version=2)
>   • W4A16-FP: Float8WeightOnlyConfig(version=2)
>   • W8A8-FP-dynamic: Float8DynamicActivationFloat8WeightConfig(version=2)

By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant.

@namgyu-youn
Contributor Author

By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant.

Oh yes, it was a typo (W8A16 is right), and W4A16-INT (Int4WeightOnlyConfig(group_size=32, version=2)) is of interest. In my experience, and per https://arxiv.org/html/2411.02355v3, W4A16-INT is the most efficient choice for synchronous deployments, while W8A8-INT maximizes throughput in asynchronous settings.

Because the current AWQ/SmoothQuant tests only work with the old APIs (version 1), we can replace them with the new APIs like Int4WeightOnlyConfig(group_size=32, version=2), I guess.

@Xia-Weiwen
Collaborator

By "validating them", do you mean adding test cases? And are W4A16 and W8A16 (I guess there is a typo in your comment) really needed for SmoothQuant? For W4A16 , it would be much the same as AWQ. And for W8A16, I think accuracy is generally good enough without SmoothQuant.

Oh yes, it was a typo (W8A16 is right), and W4A16-INT (Int4WeightOnlyConfig(group_size=32, version=2)) is of interest. In my last experience and https://arxiv.org/html/2411.02355v3, W4A16-INT is the most efficient choice for synchronous deployments, while W8A8-INT maximize throughput in asynchronous settings.

Because current AWQ/SmoothQuant test is only working with old APIs (version 1), we can replace it with new APIs like Int4WeightOnlyConfig(group_size=32, version=2) I guess.

I see. Thanks. We will evaluate that.

@Xia-Weiwen
Collaborator

Hi @namgyu-youn May I know if you have a timeline to land this? Thanks.

((32, 128), 64, 256), # 3D
],
)
def test_int8_linear_quantization_accuracy(
Contributor


can this one be combined with test_int8_linear_variants as well?

with self.assertRaises(NotImplementedError):
_ = dummy.weight[::2]

def test_index_select(self):
Contributor


I think all the tests should be modified to use config as well

kernels = {}

# Check for Triton kernels
if "torch.ops.triton" in code[0]:
Contributor


I think we should add some asserts here

- Configs are updated to global variants
Contributor Author

namgyu-youn commented Nov 4, 2025

Updated logs:

self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))

# Int8DynamicActivationInt8WeightConfig uses per-row (PerRow)
# Int8WeightOnlyConfig uses per-tensor (PerTensor)
Contributor

@jerryzh168 jerryzh168 Nov 6, 2025


it should be per row I think?

group_size = weight.shape[-1]

)
else:
assert config.version == 2, f"Unexpected version: {config.version}"
block_size = [weight.shape[0], weight.shape[1]]
Contributor

@jerryzh168 jerryzh168 Nov 6, 2025


why does this default to per tensor? I think it should follow the existing logic from L1376-1378?

Comment on lines 233 to 237
self.assertLess(
torch.abs(dequantized - test_data).max().item(),
0.1,
msg=f"Dequantization error exceeds tolerance of {0.1}",
)
Contributor


maybe check SQNR with compute_error instead of a hardcoded absolute value?
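
For example, a sketch of the suggested SQNR-based check inside the test method (the 20 dB threshold is an arbitrary illustration, not a value from the PR):

from torchao.quantization.utils import compute_error

sqnr = compute_error(test_data, dequantized)
self.assertGreater(sqnr, 20.0, f"SQNR too low: {sqnr:.1f} dB")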

Comment on lines 251 to 258
has_triton = "triton" in code[0].lower() # Triton
has_fbgemm = "fbgemm" in code[0].lower() # FB-GEMM
has_int_mm = "_int_mm" in code[0] # Int8 MatMul

self.assertTrue(
has_triton or has_fbgemm or has_int_mm,
f"No int8 quantization kernels found. has_triton={has_triton}, has_fbgemm={has_fbgemm}, has_int_mm={has_int_mm}",
)
Contributor


is it possible to be more strict and spell out the kernels called and their order? Like:

FileCheck().check_count(
"torch.ops.triton.quantize_fp8_row.default(", 1
).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default(", 1).check_not(
".run("
).run(code[0])

)

x_vals = activation_tensor.qdata.reshape(-1, activation_tensor.qdata.shape[-1])
x_scales = preprocess_scale(activation_tensor.scale, x_vals.shape)
Contributor

@jerryzh168 jerryzh168 Nov 6, 2025


why is this called? can you match the code with

def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
exactly?

Contributor Author

namgyu-youn commented Nov 11, 2025

self.dtype = torch.bfloat16
self.batch_size = 32

torch.manual_seed(42)
Contributor


nit: this can probably be moved to TorchAOIntegrationTestCase
can also add 8e3b3da

feel free to do in a separate PR though

dtype=self.dtype,
device="cuda",
)
linear.weight.data = self.weight_fp.cuda()
Contributor


why do we do this? doesn't sound very necessary?

Int8WeightOnlyConfig(version=2),
],
)
def test_per_row_scale_shape(self, dtype, config):
Contributor


nit: can you merge the checks in this test into the previous test, test_int8_linear_variants?

f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, test_data)}"
)

@common_utils.parametrize(
Contributor


does this have to be parametrized? I think what we need here is to check that the code contains a sequence of ops / kernel calls, like this:

FileCheck().check_count(
"torch.ops.triton.quantize_fp8_row.default(", 1
).check_count("torch.ops.fbgemm.f8f8bf16_rowwise.default(", 1).check_not(
".run("
).run(code[0])

I think we can check (1) the quantize op and then (2) the mm op extern_kernels._int_mm in a single run (see the example); that should be enough
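
A sketch of what that single-run check could look like; the "quantize" substring is a placeholder guess, while extern_kernels._int_mm is the kernel named above:

from torch.testing import FileCheck

# expect some quantize kernel for the activation first, then the int8 matmul
FileCheck().check("quantize").check("extern_kernels._int_mm").run(code[0])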

Contributor Author

@namgyu-youn namgyu-youn Nov 11, 2025


I see, thanks. We don't have to go with multiple runs.

Int8WeightOnlyConfig(version=2),
],
)
def test_dequantization_accuracy(self, config):
Contributor


additional comments for this test: (1) increase the size of the linear layer? (2) I think we don't need to overwrite the weight; we can just save the floating-point weight (deepcopy) before quantization and compare the results
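
A sketch of point 2, assuming the Int8Tensor subclass exposes dequantize() as described in the PR summary; the sizes and the SQNR bar are illustrative:

import copy

import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error

linear = torch.nn.Linear(512, 1024, dtype=torch.bfloat16, device="cuda")
weight_fp = copy.deepcopy(linear.weight)            # keep the original fp weight
quantize_(linear, Int8WeightOnlyConfig(version=2))
dequantized = linear.weight.dequantize()            # assumes Int8Tensor.dequantize()
assert compute_error(weight_fp, dequantized) > 20   # arbitrary SQNR bar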

Comment on lines +62 to +65
self.weight_fp = torch.randn(*self.test_shape, dtype=self.dtype)
self.input_fp = torch.randn(*self.test_shape, dtype=self.dtype)
self.bias = torch.randn(self.test_shape[0], dtype=self.dtype)
self.block_size = list(self.test_shape)
Contributor


I feel we probably don't need these; it's also easier for people to follow when everything (or most things) is defined in the test itself


# Unsupported case for now, this would be 1 scale per data element
# Per-tensor quantization (scalar scale)
if scale.numel() == 1:
Contributor


note: I think we can just check for ndim consistently everywhere, after #3324 is fixed

Contributor


also, isn't handling for per-tensor and per-row already included in the original code?

if block_size_for_dim == 1:
# Scale is per-element along this dimension
# Slice away as normal
return aten.slice.Tensor(scale, dim, start, end, step)
else:
# There is blocking in this dimension
# Calculate which scale elements correspond to the sliced data
scale_start = start // block_size_for_dim if start is not None else None
scale_end = (
(end + block_size_for_dim - 1) // block_size_for_dim
if end is not None
else None
)
# Error on Step > 1
if step > 1:
raise NotImplementedError(
"Slicing with step > 1 is not implemented for scale tensors."
)
return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)

Comment on lines +178 to +182
if len(act_kwargs.block_size) != input_ndim:
if input_ndim == 3 and len(act_kwargs.block_size) == 2:
block_size_updated = [1] + list(act_kwargs.block_size)
else:
block_size_updated = list(act_kwargs.block_size)[-input_ndim:]
Contributor


we kind of changed the meaning of block_size used in PerBlock quant recently, check

elif isinstance(granularity, PerBlock):

when is this code needed?

Contributor


in principle we shouldn't update block_size here; instead, we should make sure block_size makes sense and is consistent throughout the code base

quantized_weight = Int8Tensor.from_hp(
weight,
block_size,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(block_size=block_size),
Contributor


why does activation use the same block_size as weight?

Contributor


I think we can change the block_size argument to granularity, since the block_size is unknown for the activation (its shape is not known ahead of time)
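
A sketch of that granularity-based alternative, reusing the names from the snippet above; whether QuantizeTensorToInt8Kwargs accepts a granularity argument is part of the proposal, not the current PR:

from torchao.quantization.granularity import PerRow

quantized_weight = Int8Tensor.from_hp(
    weight,
    block_size,
    # the activation shape is unknown here, so describe its quantization by
    # granularity instead of a concrete block_size
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=PerRow()),
)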
