@@ -521,31 +521,45 @@ def test_averaging_trigger():
521521
522522
@pytest.mark.forked
@pytest.mark.parametrize("target_group_size", [None, 2])
def test_averaging_cancel(target_group_size):
    """Check that cancelling some averaging steps does not break the remaining peers.

    Four averagers (every other one in client mode) each schedule a step that waits for
    an explicit trigger. The steps of peers 0 and 1 are cancelled before any trigger is
    given; the other peers are then allowed to run all-reduce and must produce a result.

    :param target_group_size: forwarded to every DecentralizedAverager; when it is not None,
      the size of each finished group is also compared against it
    """
    dht_instances = launch_dht_instances(4)
    averagers = []
    try:
        # Built one by one (not via a generator) so that averagers created before
        # a failing constructor are still shut down in the `finally` clause.
        for i, dht in enumerate(dht_instances):
            averagers.append(
                DecentralizedAverager(
                    averaged_tensors=[torch.randn(3)],
                    dht=dht,
                    min_matchmaking_time=0.5,
                    request_timeout=0.3,
                    client_mode=(i % 2 == 0),
                    target_group_size=target_group_size,
                    prefix="mygroup",
                    start=True,
                )
            )

        step_controls = [averager.step(wait=False, require_trigger=True) for averager in averagers]

        peer_inds_to_cancel = (0, 1)

        for peer_index in peer_inds_to_cancel:
            step_controls[peer_index].cancel()

        time.sleep(0.05)

        # Trigger every surviving peer before waiting on any result: the first
        # `result()` call below presumably blocks until that peer's group finishes.
        for i, control in enumerate(step_controls):
            if i not in peer_inds_to_cancel:
                control.allow_allreduce()

        for i, control in enumerate(step_controls):
            if i in peer_inds_to_cancel:
                assert control.cancelled()
            else:
                result = control.result()
                assert result is not None
                # Don't check group size when target_group_size=None, as it could change
                if target_group_size is not None:
                    assert len(result) == target_group_size
    finally:
        # Always release background processes, even when an assertion above fails;
        # averagers go first because they were constructed on top of the DHT instances.
        for process in (*averagers, *dht_instances):
            process.shutdown()
0 commit comments