9 | 9 | import os |
10 | 10 | import sys |
11 | 11 | import unittest |
12 | | -from typing import Dict, List |
| 12 | +from typing import cast, Dict, List |
13 | 13 |
14 | 14 | import cloudpickle |
15 | | -import monarch.actor |
16 | 15 | import torch |
| 16 | +import torch.distributed as dist |
| 17 | +from monarch._src.actor.actor_mesh import ActorMesh |
17 | 18 | from monarch._src.actor.host_mesh import create_local_host_mesh, fake_in_process_host |
18 | | -from monarch.actor import Actor, endpoint |
| 19 | +from monarch.actor import Actor, current_rank, current_size, endpoint, this_host |
19 | 20 |
20 | 21 |
21 | 22 | class CudaInitTestActor(Actor): |
@@ -46,6 +47,42 @@ async def is_cuda_initialized(self) -> bool: |
46 | 47 | return self.cuda_initialized |
47 | 48 |
48 | 49 |
| 50 | +class TorchDistributedActor(Actor): |
| 51 | + """Actor that initializes CUDA and checks environment variables""" |
| 52 | + |
| 53 | + def __init__(self) -> None: |
| 54 | + self.rank = int(current_rank()["gpus"]) |
| 55 | + self.world_size = int(current_size()["gpus"]) |
| 56 | + self.port = 29500 |
| 57 | + os.environ["MASTER_ADDR"] = "localhost" |
| 58 | + os.environ["MASTER_PORT"] = str(self.port) |
| 59 | + |
| 60 | + @endpoint |
| 61 | + def init_torch_distributed(self) -> None: |
| 62 | + if not dist.is_initialized(): |
| 63 | + dist.init_process_group( |
| 64 | + backend="nccl", |
| 65 | + world_size=self.world_size, |
| 66 | + rank=self.rank, |
| 67 | + ) |
| 68 | + |
| 69 | + @endpoint |
| 70 | + def is_initialized(self) -> bool: |
| 71 | + return dist.is_initialized() |
| 72 | + |
| 73 | + # Cleanup is a special function called automatically on actor stop. |
| 74 | + def __cleanup__(self, exc: Exception | None) -> None: |
| 75 | + self.logger.info(f"Cleanup called with exception: {exc}") |
| 76 | + if dist.is_initialized(): |
| 77 | + dist.destroy_process_group() |
| 78 | + |
| 79 | + |
| 80 | +class IsTorchInitializedActor(Actor): |
| 81 | + @endpoint |
| 82 | + def is_initialized(self) -> bool: |
| 83 | + return dist.is_initialized() |
| 84 | + |
| 85 | + |
49 | 86 | class TestEnvBeforeCuda(unittest.IsolatedAsyncioTestCase): |
50 | 87 | """Test that the env vars are setup before cuda init""" |
51 | 88 |
@@ -149,3 +186,16 @@ async def test_proc_mesh_with_dictionary_env(self) -> None: |
149 | 186 | env_vars.get("CUDA_DEVICE_MAX_CONNECTIONS"), |
150 | 187 | "1", |
151 | 188 | ) |
| 189 | + |
| 190 | + async def test_cleanup_torch_distributed(self) -> None: |
| 191 | + """Test that calling stop on the actor destroys the process group""" |
| 192 | + proc_mesh = this_host().spawn_procs(per_host={"gpus": 1}) |
| 193 | + |
| 194 | + actor = proc_mesh.spawn("torch_init", TorchDistributedActor) |
| 195 | + tester = proc_mesh.spawn("check", IsTorchInitializedActor) |
| 196 | + await actor.init_torch_distributed.call_one() |
| 197 | + self.assertTrue(await actor.is_initialized.call_one()) |
| 198 | + # Stop the actor and verify that cleanup ran by checking the |
| 199 | + # torch.distributed state from another actor on the same proc. |
| 200 | + await cast(ActorMesh[TorchDistributedActor], actor).stop() |
| 201 | + self.assertFalse(await tester.is_initialized.call_one()) |
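
For reference, a minimal standalone sketch of the init/teardown lifecycle this test exercises, using plain torch.distributed with the gloo backend in a single process (no monarch involved; the backend choice, port, and helper names are assumptions for illustration, and teardown() mirrors the check-then-destroy pattern in TorchDistributedActor.__cleanup__):

import os

import torch.distributed as dist


def init_single_process_group(port: int = 29501) -> None:
    # The default env:// rendezvous requires MASTER_ADDR/MASTER_PORT
    # to be set before init_process_group is called.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", world_size=1, rank=0)


def teardown() -> None:
    # Same guarded destroy as __cleanup__ above.
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    init_single_process_group()
    assert dist.is_initialized()
    teardown()
    assert not dist.is_initialized()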