@@ -444,25 +444,27 @@ def test_bmm(self):
         # only support per row quantization
         config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
 
-        class M(torch.nn.Module):
+        class Model(torch.nn.Module):
             def __init__(self, weight):
                 super().__init__()
                 self.weight = weight
 
             def forward(self, x):
-                return torch.bmm(x, self.weight)
+                return torch.bmm(x, self.weight.transpose(-2, -1))
 
         dtype = torch.bfloat16
         device = "cuda"
-        input = torch.randn(10, 32, 128, dtype=dtype, device=device)
-        weight = torch.randn(10, 128, 256, dtype=dtype, device=device)
-        m = M(weight).eval()
+
+        B, M, K, N = 10, 32, 128, 256
+
+        input = torch.randn(B, M, K, dtype=dtype, device=device)
+        weight = torch.randn(B, N, K, dtype=dtype, device=device)
+        m = Model(weight).eval()
         original = m(input)
-        # we need to transpose the weight first for bmm
-        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
         quantize_(m, config, filter_fn=lambda x, fqn: True)
         quantized = m(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        sqnr = compute_error(original, quantized)
+        self.assertTrue(sqnr > 20)
 
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize(
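
For readers skimming the hunk above: the weight is now created in (B, N, K) layout and transposed back to (B, K, N) inside forward, so torch.bmm still contracts over K. Below is a minimal sketch of that pattern, using the same names and shapes as the test but running on CPU in float32 and skipping the quantize_ step, so it is illustrative only:

import torch

class Model(torch.nn.Module):
    def __init__(self, weight):
        super().__init__()
        # weight is stored in (B, N, K) layout, as in the updated test
        self.weight = weight

    def forward(self, x):
        # transpose the last two dims back to (B, K, N) so bmm contracts over K
        return torch.bmm(x, self.weight.transpose(-2, -1))

B, M, K, N = 10, 32, 128, 256
x = torch.randn(B, M, K)        # activation: (B, M, K)
weight = torch.randn(B, N, K)   # weight: (B, N, K)
out = Model(weight).eval()(x)
assert out.shape == (B, M, N)
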
@@ -551,6 +553,10 @@ def test_cat(self, granularity, sizes):
         self.assertEqual(cat_qweight2.qdata, ref_data)
         self.assertEqual(cat_qweight2.scale, ref_scale)
 
+    # TODO(future PR): add this back
+    @unittest.skip(
+        "This requires rowwise scaling for weight in layout BKN across axis 1 to work"
+    )
     @unittest.skipIf(not is_sm_at_least_90(), "Need sm90+")
     @unittest.skipIf(not _is_fbgemm_gpu_genai_available(), "Need fbgemm_gpu_genai")
     def test_moe_weight_reshape_ops(self):
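
The new skip message in the second hunk refers to rowwise float8 scaling of a weight in (B, K, N) layout across axis 1. As a rough illustration of what reducing over one axis of a 3D weight means (a hedged, amax-based sketch, not the fbgemm-backed path the skipped test targets):

import torch

B, K, N = 10, 128, 256
w = torch.randn(B, K, N)

F8_E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn

# rowwise scale across axis 1 (K): one scale per (batch, n) pair, shape (B, 1, N)
scale_k = w.abs().amax(dim=1, keepdim=True) / F8_E4M3_MAX

# rowwise scale across axis 2 (N): one scale per (batch, k) pair, shape (B, K, 1)
scale_n = w.abs().amax(dim=2, keepdim=True) / F8_E4M3_MAX

# quantize / dequantize round trip using the axis-1 scale
w_fp8 = (w / scale_k).to(torch.float8_e4m3fn)
w_roundtrip = w_fp8.to(torch.float32) * scale_k
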