 )
 from monarch._rust_bindings.monarch_hyperactor.proc import ActorId
 from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask
+from monarch._rust_bindings.monarch_hyperactor.shape import Extent

 from monarch._src.actor.actor_mesh import ActorMesh, Channel, context, Port
-from monarch._src.actor.allocator import AllocHandle
+from monarch._src.actor.allocator import AllocHandle, ProcessAllocator
 from monarch._src.actor.future import Future
 from monarch._src.actor.host_mesh import (
     create_local_host_mesh,
     fake_in_process_host,
     HostMesh,
 )
-from monarch._src.actor.proc_mesh import ProcMesh
+from monarch._src.actor.proc_mesh import _get_bootstrap_args, ProcMesh
 from monarch._src.actor.v1.host_mesh import (
+    _bootstrap_cmd,
     fake_in_process_host as fake_in_process_host_v1,
     HostMesh as HostMeshV1,
     this_host as this_host_v1,
@@ -466,7 +468,7 @@ async def no_more(self) -> None:


 @pytest.mark.parametrize("v1", [True, False])
-@pytest.mark.timeout(30)
+@pytest.mark.timeout(60)
 async def test_async_concurrency(v1: bool):
     """Test that async endpoints will be processed concurrently."""
     pm = spawn_procs_on_this_host(v1, {})
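
Every parametrized test in this diff creates its proc mesh through spawn_procs_on_this_host(v1, per_host), a helper defined elsewhere in this test module and not shown in these hunks. A minimal sketch of what such a helper might look like, assuming the only difference between the two paths is which host-mesh entry point is used (the body below is illustrative, not the file's actual definition):

    def spawn_procs_on_this_host(v1: bool, per_host: dict[str, int]):
        # Hypothetical sketch; the real helper lives outside these hunks.
        if v1:
            # v1 host meshes take an explicit proc-mesh name, as seen later in this diff.
            return this_host_v1().spawn_procs(name="proc", per_host=per_host)
        return this_host().spawn_procs(per_host=per_host)
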
@@ -603,8 +605,9 @@ def _handle_undeliverable_message(
         return True


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_actor_log_streaming() -> None:
+async def test_actor_log_streaming(v1: bool) -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr
@@ -631,7 +634,7 @@ async def test_actor_log_streaming() -> None:
         sys.stderr = stderr_file

         try:
-            pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
+            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)

             # Disable streaming logs to client
@@ -671,7 +674,10 @@ async def test_actor_log_streaming() -> None:
             await am.print.call("has print streaming too")
             await am.log.call("has log streaming as level matched")

-            await pm.stop()
+            if not v1:
+                await pm.stop()
+            else:
+                await asyncio.sleep(1)

             # Flush all outputs
             stdout_file.flush()
@@ -752,8 +758,9 @@ async def test_actor_log_streaming() -> None:
             pass


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_alloc_based_log_streaming() -> None:
+async def test_alloc_based_log_streaming(v1: bool) -> None:
     """Test both AllocHandle.stream_logs = False and True cases."""

     async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
@@ -770,23 +777,45 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:


         try:
             # Create proc mesh with custom stream_logs setting
-            host_mesh = create_local_host_mesh()
-            alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+            if not v1:
+                host_mesh = create_local_host_mesh()
+                alloc_handle = host_mesh._alloc(hosts=1, gpus=2)
+
+                # Override the stream_logs setting
+                custom_alloc_handle = AllocHandle(
+                    alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
+                )
+
+                pm = ProcMesh.from_alloc(custom_alloc_handle)
+            else:

-            # Override the stream_logs setting
-            custom_alloc_handle = AllocHandle(
-                alloc_handle._hy_alloc, alloc_handle._extent, stream_logs
-            )
+                class ProcessAllocatorStreamLogs(ProcessAllocator):
+                    def _stream_logs(self) -> bool:
+                        return stream_logs
+
+                alloc = ProcessAllocatorStreamLogs(*_get_bootstrap_args())
+
+                host_mesh = HostMeshV1.allocate_nonblocking(
+                    "host",
+                    Extent(["hosts"], [1]),
+                    alloc,
+                    bootstrap_cmd=_bootstrap_cmd(),
+                )
+
+                pm = host_mesh.spawn_procs(name="proc", per_host={"gpus": 2})

-            pm = ProcMesh.from_alloc(custom_alloc_handle)
             am = pm.spawn("printer", Printer)

             await pm.initialized

             for _ in range(5):
                 await am.print.call(f"{test_name} print streaming")

-            await pm.stop()
+            if not v1:
+                await pm.stop()
+            else:
+                # Wait for at least the aggregation window (3 seconds)
+                await asyncio.sleep(5)

             # Flush all outputs
             stdout_file.flush()
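
The two branches above reach the same knob by different routes: the v0 path carries stream_logs on the AllocHandle that ProcMesh.from_alloc consumes, while the v1 path asks the allocator itself through its _stream_logs() hook before HostMeshV1.allocate_nonblocking builds the host mesh. A condensed sketch of just that difference, using the calls shown in the hunk (the QuietAllocator name is illustrative only):

    # v0: the flag rides along with the allocation handle.
    handle = host_mesh._alloc(hosts=1, gpus=2)
    pm_v0 = ProcMesh.from_alloc(AllocHandle(handle._hy_alloc, handle._extent, False))

    # v1: the flag is a property of the allocator backing the host mesh.
    class QuietAllocator(ProcessAllocator):
        def _stream_logs(self) -> bool:
            return False
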
@@ -810,18 +839,18 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
                 # When stream_logs=False, logs should not be streamed to client
                 assert not re.search(
                     rf"similar log lines.*{test_name} print streaming", stdout_content
-                ), f"stream_logs=True case: {stdout_content}"
+                ), f"stream_logs=False case: {stdout_content}"
                 assert re.search(
                     rf"{test_name} print streaming", stdout_content
-                ), f"stream_logs=True case: {stdout_content}"
+                ), f"stream_logs=False case: {stdout_content}"
             else:
                 # When stream_logs=True, logs should be streamed to client (no aggregation by default)
                 assert re.search(
                     rf"similar log lines.*{test_name} print streaming", stdout_content
-                ), f"stream_logs=False case: {stdout_content}"
+                ), f"stream_logs=True case: {stdout_content}"
                 assert not re.search(
                     rf"\[[0-9]\]{test_name} print streaming", stdout_content
-                ), f"stream_logs=False case: {stdout_content}"
+                ), f"stream_logs=True case: {stdout_content}"

         finally:
             # Ensure file descriptors are restored even if something goes wrong
@@ -836,8 +865,9 @@ async def test_stream_logs_case(stream_logs: bool, test_name: str) -> None:
     await test_stream_logs_case(True, "stream_logs_true")


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_logging_option_defaults() -> None:
+async def test_logging_option_defaults(v1: bool) -> None:
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
     original_stderr_fd = os.dup(2)  # stderr
@@ -864,14 +894,18 @@ async def test_logging_option_defaults() -> None:
         sys.stderr = stderr_file

         try:
-            pm = spawn_procs_on_this_host(v1=False, per_host={"gpus": 2})
+            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)

             for _ in range(5):
                 await am.print.call("print streaming")
                 await am.log.call("log streaming")

-            await pm.stop()
+            if not v1:
+                await pm.stop()
+            else:
+                # Wait for > default aggregation window (3 seconds)
+                await asyncio.sleep(5)

             # Flush all outputs
             stdout_file.flush()
@@ -949,7 +983,8 @@ def __init__(self):

 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
-async def test_flush_called_only_once() -> None:
+@pytest.mark.parametrize("v1", [True, False])
+async def test_flush_called_only_once(v1: bool) -> None:
     """Test that flush is called only once when ending an ipython cell"""
     mock_ipython = MockIPython()
     with unittest.mock.patch(
@@ -961,8 +996,8 @@ async def test_flush_called_only_once() -> None:
961996 "monarch._src.actor.logging.flush_all_proc_mesh_logs"
962997 ) as mock_flush :
963998 # Create 2 proc meshes with a large aggregation window
964- pm1 = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
965- _ = this_host (). spawn_procs ( per_host = {"gpus" : 2 })
999+ pm1 = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
1000+ _ = spawn_procs_on_this_host ( v1 , per_host = {"gpus" : 2 })
9661001 # flush not yet called unless post_run_cell
9671002 assert mock_flush .call_count == 0
9681003 assert mock_ipython .events .registers == 0
@@ -976,8 +1011,9 @@ async def test_flush_called_only_once() -> None:

 # oss_skip: pytest keeps complaining about mocking get_ipython module
 @pytest.mark.oss_skip
+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(180)
-async def test_flush_logs_ipython() -> None:
+async def test_flush_logs_ipython(v1: bool) -> None:
     """Test that logs are flushed when get_ipython is available and post_run_cell event is triggered."""
     # Save original file descriptors
     original_stdout_fd = os.dup(1)  # stdout
@@ -1003,8 +1039,8 @@ async def test_flush_logs_ipython() -> None:
     ), unittest.mock.patch("monarch._src.actor.logging.IN_IPYTHON", True):
         # Make sure we can register and unregister callbacks
         for _ in range(3):
-            pm1 = this_host().spawn_procs(per_host={"gpus": 2})
-            pm2 = this_host().spawn_procs(per_host={"gpus": 2})
+            pm1 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
+            pm2 = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
             am1 = pm1.spawn("printer", Printer)
             am2 = pm2.spawn("printer", Printer)

@@ -1108,8 +1144,9 @@ async def test_flush_logs_fast_exit() -> None:
     ), process.stdout


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_flush_on_disable_aggregation() -> None:
+async def test_flush_on_disable_aggregation(v1: bool) -> None:
     """Test that logs are flushed when disabling aggregation.

     This tests the corner case: "Make sure we flush whatever in the aggregators before disabling aggregation."
@@ -1130,7 +1167,7 @@ async def test_flush_on_disable_aggregation() -> None:
         sys.stdout = stdout_file

         try:
-            pm = this_host().spawn_procs(per_host={"gpus": 2})
+            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)

             # Set a long aggregation window to ensure logs aren't flushed immediately
@@ -1151,7 +1188,11 @@ async def test_flush_on_disable_aggregation() -> None:
             for _ in range(5):
                 await am.print.call("single log line")

-            await pm.stop()
+            if not v1:
+                await pm.stop()
+            else:
+                # Wait for > default aggregation window (3 secs)
+                await asyncio.sleep(5)

             # Flush all outputs
             stdout_file.flush()
@@ -1197,14 +1238,15 @@ async def test_flush_on_disable_aggregation() -> None:
             pass


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(120)
-async def test_multiple_ongoing_flushes_no_deadlock() -> None:
+async def test_multiple_ongoing_flushes_no_deadlock(v1: bool) -> None:
     """
     The goal is to make sure when a user sends multiple sync flushes, we are not deadlocked.
     Because now a flush call is purely sync, it is very easy to get into a deadlock.
     So we assert the last flush call will not get into such a state.
     """
-    pm = this_host().spawn_procs(per_host={"gpus": 4})
+    pm = spawn_procs_on_this_host(v1, per_host={"gpus": 4})
     am = pm.spawn("printer", Printer)

     # Generate some logs that will be aggregated but not flushed immediately
@@ -1227,8 +1269,9 @@ async def test_multiple_ongoing_flushes_no_deadlock() -> None:
     futures[-1].get()


+@pytest.mark.parametrize("v1", [True, False])
 @pytest.mark.timeout(60)
-async def test_adjust_aggregation_window() -> None:
+async def test_adjust_aggregation_window(v1: bool) -> None:
     """Test that the flush deadline is updated when the aggregation window is adjusted.

     This tests the corner case: "This can happen if the user has adjusted the aggregation window."
@@ -1249,7 +1292,7 @@ async def test_adjust_aggregation_window() -> None:
         sys.stdout = stdout_file

         try:
-            pm = this_host().spawn_procs(per_host={"gpus": 2})
+            pm = spawn_procs_on_this_host(v1, per_host={"gpus": 2})
             am = pm.spawn("printer", Printer)

             # Set a long aggregation window initially
@@ -1267,7 +1310,11 @@ async def test_adjust_aggregation_window() -> None:
             for _ in range(3):
                 await am.print.call("second batch of logs")

-            await pm.stop()
+            if not v1:
+                await pm.stop()
+            else:
+                # Wait for > aggregation window (2 secs)
+                await asyncio.sleep(4)

             # Flush all outputs
             stdout_file.flush()