@@ -303,27 +303,27 @@ class Case:
     ],
 )
 @pytest.mark.parametrize("block_m", [16, 128])
-@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter, inner_expt_opt", [
-    (False, False, False, None),
-    (True, False, False, None),
-    (False, True, False, None),
-    (False, True, True, None),
-    (True, True, False, None),
-    (True, True, True, None),
-    (False, False, False, "pad_w"),
-    (False, False, False, "pad_x"),
+@pytest.mark.parametrize("do_gather, do_scatter, inner_expt_opt", [
+    (False, False, None),
+    (True, False, None),
+    (False, True, None),
+    (False, True, None),
+    (True, True, None),
+    (True, True, None),
+    (False, False, "pad_w"),
+    (False, False, "pad_x"),
 ])
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
-def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
     # the frame that called pytest.skip, including all the tensors, leading to OOM.
     skip_message = None
     try:
-        _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+        _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
                  n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
                  x_transpose, w_transpose, y_transpose,
                  device, opt_flags_scope)
@@ -333,7 +333,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     if skip_message is not None:
         pytest.skip(skip_message)
 
-def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
              n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
              x_transpose, w_transpose, y_transpose,
              device, opt_flags_scope):
@@ -362,9 +362,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_
     if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
         pytest.xfail("float8_e4m3fnuz only tested on AMD CDNA3 Platform")
 
-    if fused_scatter and split_k is not None and split_k > 1:
-        pytest.xfail("fused scatter scratchpad not supported with split_k")
-
     if hbm_swizzling:
         if is_hip():
             if not is_hip_cdna4():
@@ -414,7 +411,6 @@ def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_
         "block_m": block_m,
         "block_k": block_k,
         "split_k": split_k,
-        "fused_scatter": fused_scatter,
         "is_persistent": is_persistent,
         "epilogue_subtile": epilogue_subtile,
     }
@@ -727,12 +723,11 @@ def test_set_idle_sms():
     (800, 800, 400, "batched"),
 ])
 @pytest.mark.parametrize("split_k", [1, 2])
-@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
-    (False, False, False),
-    (True, False, False),
-    (False, True, False),
-    (True, True, False),
-    (True, True, True),
+@pytest.mark.parametrize("do_gather, do_scatter", [
+    (False, False),
+    (True, False),
+    (False, True),
+    (True, True),
 ])
 @pytest.mark.parametrize("is_persistent, epilogue_subtile", [
     (False, None),
@@ -744,16 +739,13 @@ def test_set_idle_sms():
     (1.0, 1.2),
     (0.7, 1.0),
 ])
-def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
+def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, is_persistent, epilogue_subtile,
                    swiglu_alpha, swiglu_limit, device, opt_flags_scope):
-    if fused_scatter and split_k > 1:
-        pytest.xfail("fused scatter scratchpad not supported with split_k")
     torch.manual_seed(0)
     constraints = {
         "is_persistent": is_persistent,
         "epilogue_subtile": epilogue_subtile,
         "split_k": split_k,
-        "fused_scatter": fused_scatter,
     }
     n_expts_tot, n_expts_act = 1, 1
     opt_flags.update_opt_flags_constraints(constraints)
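
Note on the try/except retained in test_op above: per the in-code comment, pytest.skip() raises an exception whose traceback can keep the skipping frame (and every tensor it references) alive, so the test catches it and re-invokes pytest.skip() from the thin wrapper frame. Below is a minimal, self-contained sketch of that pattern; the names (_heavy_case, test_heavy_case) are illustrative only and not part of this diff, and it assumes pytest's pytest.skip.Exception alias for the Skipped exception.

import pytest

def _heavy_case(n):
    buf = [0.0] * n                      # stands in for large tensor allocations
    if n > 1000:
        pytest.skip(f"n={n} too large for this sketch")
    assert sum(buf) == 0.0

def test_heavy_case():
    skip_message = None
    try:
        _heavy_case(10_000)
    except pytest.skip.Exception as e:   # catch the Skipped exception...
        skip_message = str(e)            # ...and leave the heavy frame behind
    if skip_message is not None:
        pytest.skip(skip_message)        # re-skip from this lightweight frame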