@@ -293,38 +293,39 @@ def test_infer_layout_cast_layout(self):
293293 self .checkInLayouts (cast , [wgmma_layout ])
294294 self .checkOutLayouts (cast , [wgmma_layout ])
295295
296- @parameterized .parameters (
297- (0 , mgpu .WGMMA_ROW_LAYOUT , None ),
298- (1 , mgpu .WGMMA_COL_LAYOUT , None ),
299- (0 , None , mgpu .WGMMA_LAYOUT ),
300- (1 , None , mgpu .WGMMA_LAYOUT ),
301- (0 , mgpu .TCGEN05_ROW_LAYOUT , None ),
302- (0 , None , mgpu .TCGEN05_LAYOUT ),
303- (1 , None , mgpu .TCGEN05_LAYOUT ),
296+ @parameterized .product (
297+ layout = (
298+ mtu .RegisterLayout .WGMMA ,
299+ mtu .RegisterLayout .TCGEN05 ,
300+ mtu .RegisterLayout .TCGEN05_TMEM_NATIVE ,
301+ mtu .RegisterLayout .TCGEN05_M64_COLLECTIVE ,
302+ ),
303+ axis = (0 , 1 ),
304+ hint_on_input = (True , False ),
304305 )
305- def test_infer_broadcast_in_dim_layout (self , broadcast_dim , in_cast , out_cast ):
306+ def test_infer_broadcast_in_dim_layout (self , layout , axis , hint_on_input ):
306307 in_shape = (128 ,)
307308 out_shape = (128 , 128 )
309+ dtype = ir .F32Type .get ()
310+ out_layout = layout .to_mgpu (out_shape , dtype )
311+ in_layout = out_layout .reduce ((1 - axis ,))
308312
309313 with ir .InsertionPoint (self .module .body ):
310- [x ] = undefs (ir .VectorType .get (in_shape , ir .F32Type .get ()))
311- x = layout_cast (x , in_cast ) if in_cast is not None else x
312- out_type = ir .VectorType .get (out_shape , ir .F32Type .get ())
313- bcast = mgpu .dialect .BroadcastInDimOp (out_type , x , [broadcast_dim ])
314- if out_cast is not None :
315- layout_cast (bcast .result , out_cast )
316-
317- # The tests always expect WGMMA or TCGEN05 as the out layout.
318- if out_cast == mgpu .TCGEN05_LAYOUT or in_cast == mgpu .TCGEN05_ROW_LAYOUT :
319- out_layout = mgpu .TCGEN05_LAYOUT
320- else :
321- out_layout = mgpu .WGMMA_LAYOUT
314+ [x ] = undefs (ir .VectorType .get (in_shape , dtype ))
315+ if hint_on_input :
316+ x = layout_cast (x , in_layout )
317+ out_type = ir .VectorType .get (out_shape , dtype )
318+ bcast = mgpu .dialect .BroadcastInDimOp (out_type , x , [axis ])
319+ if not hint_on_input :
320+ layout_cast (bcast .result , out_layout )
322321
323- in_layout = out_layout .reduce ((1 - broadcast_dim ,))
322+ if hint_on_input and axis == 1 and layout == mtu .RegisterLayout .TCGEN05 :
323+ # Both TCGEN05 and WGMMA are valid layout candidates. WGMMA is tried first.
324+ out_layout = fa .WGMMA_LAYOUT
324325
325326 mgpu .infer_layout (self .module )
326- self .checkInLayouts (bcast , [layouts . to_layout_attr ( in_layout ) ])
327- self .checkOutLayouts (bcast , [layouts . to_layout_attr ( out_layout ) ])
327+ self .checkInLayouts (bcast , [in_layout ])
328+ self .checkOutLayouts (bcast , [out_layout ])
328329
329330 @parameterized .parameters (
330331 (1 , mgpu .WGMMA_LAYOUT , None , None ),
0 commit comments