Skip to content

Commit 614f33e

Browse files
committed
Fix tests
1 parent 3242af1 commit 614f33e

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,19 @@ def test_embedding(self):
5151
self.assertTrue(error > 20)
5252

5353
def test_add(self):
54+
dtype = torch.bfloat16
5455
device = "cpu"
55-
a = torch.randint(low=0, high=128, size=(10,), device=device)
56+
a = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
57+
b = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
5658
a_orig = a.clone()
57-
b = torch.randint(low=0, high=128, size=(10,), device=device)
58-
sum = a + b
59+
sum = a.weight + b.weight
5960

6061
quantize_(a, self.config)
61-
a_quant_sum = a + b
62+
a_quant_sum = a.weight + b.weight
6263

6364
quantize_(b, self.config)
64-
b_quant_sum = a_orig + b
65-
a_b_quant_sum = a + b
65+
b_quant_sum = a_orig.weight + b.weight
66+
a_b_quant_sum = a.weight + b.weight
6667

6768
for quantized_sum in [a_quant_sum, b_quant_sum, a_b_quant_sum]:
6869
error = compute_error(sum, quantized_sum)

0 commit comments

Comments
 (0)