@@ -309,8 +309,6 @@ class MockAttention(torch.nn.Module):
         # static token is not supported
         # channel is not supported
         # group is not supported
-        # tensor group is not supported
-        # block is not supported
         (
             QuantizationArgs(
                 num_bits=4,
@@ -340,6 +338,34 @@ class MockAttention(torch.nn.Module):
             ),
             0.55,
         ),
+        # block is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="int",
+                symmetric=True,
+                strategy="attn_head",
+            ),
+            torch.tensor([[[0.0]], [[12.0]]]),
+            torch.tensor([[[11.0]], [[23.0]]]),
+            torch.tensor(
+                [
+                    [
+                        [
+                            [0.0000, 1.4688, 1.4688, 2.9375],
+                            [4.4062, 4.4062, 5.8750, 7.3438],
+                            [7.3438, 8.8125, 10.2500, 10.2500],
+                        ],
+                        [
+                            [12.2500, 12.2500, 15.3125, 15.3125],
+                            [15.3125, 18.3750, 18.3750, 18.3750],
+                            [21.5000, 21.5000, 21.5000, 21.5000],
+                        ],
+                    ]
+                ]
+            ),
+            0.55,
+        ),
     ],
 )
 def test_static_attention_quantization(
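For context, the added "attn_head" case fake-quantizes a (batch, num_heads, seq_len, head_dim) activation with one symmetric int4 scale per attention head, derived from the per-head min/max tensors in the tuple. Below is a minimal sketch of that mechanic, assuming a symmetric scale of max_abs / ((2**bits - 1) / 2); per_head_scale and fake_quantize are hypothetical helpers, not the library's API, and the expected tensor above additionally reflects scale-dtype rounding that this sketch does not model.

import torch

def per_head_scale(min_vals, max_vals, num_bits=4):
    # min_vals/max_vals are shaped (num_heads, 1, 1), so the resulting scale
    # broadcasts over a (batch, num_heads, seq_len, head_dim) activation.
    bit_range = 2**num_bits - 1
    return torch.maximum(min_vals.abs(), max_vals.abs()) / (bit_range / 2)

def fake_quantize(x, scale, num_bits=4):
    # Symmetric round-to-nearest quantize/dequantize into [-2^(b-1), 2^(b-1) - 1].
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    return torch.clamp(torch.round(x / scale), qmin, qmax) * scale

# Head 0 holds values 0..11 and head 1 holds 12..23, matching the min/max
# tensors in the parametrized case above.
x = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)
scale = per_head_scale(
    torch.tensor([[[0.0]], [[12.0]]]), torch.tensor([[[11.0]], [[23.0]]])
)
print(fake_quantize(x, scale))  # close to the expected tensor, up to rounding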