From 3242af1871a24d9253aed955d7daa34048b5a1c4 Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Thu, 6 Nov 2025 06:44:59 -0800
Subject: [PATCH 1/3] Implement aten.add for IntxUnpackedToInt8Tensor

---
 .../intx/test_intx_unpacked_to_int8_tensor.py | 18 ++++++++++++++++++
 .../intx/intx_unpacked_to_int8_tensor.py      | 13 +++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
index f49e2b3f8d..a5aa60187e 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
@@ -50,6 +50,24 @@ def test_embedding(self):
         error = compute_error(original, quantized)
         self.assertTrue(error > 20)
 
+    def test_add(self):
+        device = "cpu"
+        a = torch.randint(low=0, high=128, size=(10,), device=device)
+        a_orig = a.clone()
+        b = torch.randint(low=0, high=128, size=(10,), device=device)
+        sum = a + b
+
+        quantize_(a, self.config)
+        a_quant_sum = a + b
+
+        quantize_(b, self.config)
+        b_quant_sum = a_orig + b
+        a_b_quant_sum = a + b
+
+        for quantized_sum in [a_quant_sum, b_quant_sum, a_b_quant_sum]:
+            error = compute_error(sum, quantized_sum)
+            self.assertTrue(error > 20)
+
     def test_linear(self):
         dtype = torch.bfloat16
         device = "cpu"

diff --git a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
index bbbf62b412..f31f338f9e 100644
--- a/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py
@@ -355,6 +355,19 @@ def _(func, types, args, kwargs):
     return torch.nn.functional.embedding(indices, weight_tensor, **kwargs)
 
 
+@implements(aten.add.Tensor)
+def _(func, types, args, kwargs):
+    assert len(args) == 2
+    t1, t2 = args[0], args[1]
+    if isinstance(t1, IntxUnpackedToInt8Tensor):
+        assert t1.activation_quantization is None
+        t1 = t1.dequantize()
+    if isinstance(t2, IntxUnpackedToInt8Tensor):
+        assert t2.activation_quantization is None
+        t2 = t2.dequantize()
+    return t1 + t2
+
+
 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
     self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])

From 614f33ee37a5d53ae9e17ed2441399f1029ee14e Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Fri, 7 Nov 2025 09:20:30 -0800
Subject: [PATCH 2/3] Fix tests

---
 .../intx/test_intx_unpacked_to_int8_tensor.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
index a5aa60187e..e4378dba2f 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
@@ -51,18 +51,19 @@ def test_embedding(self):
         self.assertTrue(error > 20)
 
     def test_add(self):
+        dtype = torch.bfloat16
         device = "cpu"
-        a = torch.randint(low=0, high=128, size=(10,), device=device)
+        a = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
+        b = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
         a_orig = a.clone()
-        b = torch.randint(low=0, high=128, size=(10,), device=device)
-        sum = a + b
+        sum = a.weight + b.weight
 
         quantize_(a, self.config)
-        a_quant_sum = a + b
+        a_quant_sum = a.weight + b.weight
 
         quantize_(b, self.config)
-        b_quant_sum = a_orig + b
-        a_b_quant_sum = a + b
+        b_quant_sum = a_orig.weight + b.weight
+        a_b_quant_sum = a.weight + b.weight
 
         for quantized_sum in [a_quant_sum, b_quant_sum, a_b_quant_sum]:
             error = compute_error(sum, quantized_sum)

From 6cf9b150b1c37a147e1f182746c1d7c08e18900b Mon Sep 17 00:00:00 2001
From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com>
Date: Mon, 10 Nov 2025 14:03:40 -0800
Subject: [PATCH 3/3] Fix test

---
 .../workflows/intx/test_intx_unpacked_to_int8_tensor.py  | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
index e4378dba2f..15cf83ba61 100644
--- a/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py
@@ -55,7 +55,7 @@ def test_add(self):
         device = "cpu"
         a = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
         b = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
-        a_orig = a.clone()
+        a_orig = copy.deepcopy(a)
         sum = a.weight + b.weight
 
         quantize_(a, self.config)
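
For reference, a minimal usage sketch of the behavior this series adds and
tests: once a module's weight has been swapped for an IntxUnpackedToInt8Tensor
by quantize_, the new aten.add.Tensor override dequantizes the quantized
operand(s) and performs the addition in the original high-precision dtype, so
the result is a plain dense tensor. The config construction below
(IntxWeightOnlyConfig with int4 weights, PerGroup granularity, version=2) is
an assumption standing in for the test class's self.config; the real test may
build it differently.

    import torch
    from torchao.quantization import IntxWeightOnlyConfig, quantize_
    from torchao.quantization.granularity import PerGroup

    emb_a = torch.nn.Embedding(128, 256, dtype=torch.bfloat16)
    emb_b = torch.nn.Embedding(128, 256, dtype=torch.bfloat16)

    # Assumed config: version=2 is what routes IntxWeightOnlyConfig to the
    # IntxUnpackedToInt8Tensor subclass in recent torchao releases.
    config = IntxWeightOnlyConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        version=2,
    )

    # quantize_ targets nn.Linear by default, so quantizing an embedding needs
    # an explicit filter_fn.
    quantize_(emb_a, config, filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding))

    # The override dequantizes emb_a.weight and adds in bfloat16; the sum is a
    # regular torch.Tensor, not a quantized subclass.
    total = emb_a.weight + emb_b.weight
    print(type(total), total.dtype)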