@@ -91,7 +91,7 @@ def setUp(self):
9191 @common_utils .parametrize ("compile" , [True , False ])
9292 @common_utils .parametrize (
9393 "granularity" ,
94- [PerTensor (), PerRow (), (PerBlock (( 1 , 128 )) , PerBlock (( 128 , 128 ) ))],
94+ [PerTensor (), PerRow (), (PerBlock ([ 1 , 128 ]) , PerBlock ([ 128 , 128 ] ))],
9595 )
9696 @common_utils .parametrize (
9797 "kernel_preference" ,
@@ -125,7 +125,7 @@ def test_fp8_linear_variants(
125125 elif mode == "weight-only" :
126126 return unittest .skip ("unimplemented" )
127127
128- elif granularity == (PerBlock (( 1 , 128 )) , PerBlock (( 128 , 128 ) )):
128+ elif granularity == (PerBlock ([ 1 , 128 ]) , PerBlock ([ 128 , 128 ] )):
129129 if dtype is not torch .bfloat16 :
130130 return unittest .skip ("unimplemented" )
131131 elif mode != "dynamic" :
@@ -199,7 +199,7 @@ def test_fp8_linear_variants(
199199 assert qs1 .shape == (N , 1 )
200200 assert qs2 .shape == (K , 1 )
201201 else :
202- assert granularity == (PerBlock (( 1 , 128 )) , PerBlock (( 128 , 128 ) ))
202+ assert granularity == (PerBlock ([ 1 , 128 ]) , PerBlock ([ 128 , 128 ] ))
203203 assert qs1 .shape == (N // 128 , K // 128 )
204204 assert qs2 .shape == (K // 128 , N // 128 )
205205
0 commit comments