Skip to content

Commit 46cdefc

Browse files
committed
Fix tests
1 parent 3242af1 commit 46cdefc

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

test/quantization/quantize_/workflows/intx/test_intx_unpacked_to_int8_tensor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,17 @@ def test_embedding(self):
5252

5353
def test_add(self):
5454
device = "cpu"
55-
a = torch.randint(low=0, high=128, size=(10,), device=device)
56-
a_orig = a.clone()
57-
b = torch.randint(low=0, high=128, size=(10,), device=device)
58-
sum = a + b
55+
a = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
56+
b = torch.nn.Embedding(128, 256, dtype=dtype, device=device)
57+
a_orig = a.clone())
58+
sum = a.weight + b.weight
5959

6060
quantize_(a, self.config)
61-
a_quant_sum = a + b
61+
a_quant_sum = a.weight + b.weight
6262

6363
quantize_(b, self.config)
64-
b_quant_sum = a_orig + b
65-
a_b_quant_sum = a + b
64+
b_quant_sum = a_orig.weight + b.weight
65+
a_b_quant_sum = a.weight + b.weight
6666

6767
for quantized_sum in [a_quant_sum, b_quant_sum, a_b_quant_sum]:
6868
error = compute_error(sum, quantized_sum)

0 commit comments

Comments
 (0)