 import socket
 import uuid
 
+import torch
+
 from monarch._src.actor.actor_mesh import ActorMesh
 from monarch._src.actor.shape import Extent
 
@@ -41,8 +43,19 @@ class _RemoteInfoFetcher(Actor):
 
     @endpoint
     def get_info(self) -> tuple[str, str]:
+        """Returns hostname and port."""
         return socket.gethostname(), _get_port()
 
+    @endpoint
+    def get_gpu_count(self) -> int:
+        """Returns the number of GPUs available on this host."""
+        try:
+            gpu_count = torch.cuda.device_count()
+        except Exception:
+            # If CUDA is unavailable or the device query fails, assume no GPUs
+            gpu_count = 0
+        return gpu_count
+
 
 class EnvSetter(Actor):
     """Actor to set environment variables on each proc in a mesh.
@@ -87,14 +100,29 @@ async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
     singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
     fetcher = fetcher.slice(**singleton_slice)
     # Fetcher should be a singleton at this point - call_one() will fail otherwise
-
     host, port = await fetcher.get_info.call_one()
 
     # Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
     # await throwaway_procs.stop()
     return host, port
 
 
+async def get_host_gpus(host_mesh: HostMesh) -> int:
+    """Returns the number of GPUs available on the host mesh."""
+    throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
+    fetcher = throwaway_procs.spawn("_gpu_counter", _RemoteInfoFetcher)
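+    # _RemoteInfoFetcher also exposes get_gpu_count, so it is reused here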
+
+    # Reduce to a singleton
+    singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
+    fetcher = fetcher.slice(**singleton_slice)
+
+    gpu_count = await fetcher.get_gpu_count.call_one()
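+    # As in get_remote_info, stopping the throwaway procs would be the right
+    # thing to do, but Monarch does not yet handle manual stops well.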
+    return gpu_count
+
+
 async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
     """Set environment variables on a proc mesh using EnvSetter actor.
 
@@ -112,17 +140,38 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
 class GpuManager:
     """Tracks and assigns GPU devices on a host.
 
-    This currently mimics the `gpu_manager` in system_controllers - we will
-    consolidate as part of the "proper HostMesh integration" work.
+    Args:
+        available_devices: Set of GPU device IDs to manage. If None, uses all devices from 0 to max_device_count-1.
+        max_device_count: Maximum number of GPU devices on this host. Defaults to 8.
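+
+    Example:
+        >>> GpuManager(max_device_count=4).available_gpus
+        {0, 1, 2, 3}
+        >>> GpuManager({0, 1, 9}).max_device_count  # highest ID grows the count
+        10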
 
     """
 
-    def __init__(self, available_devices: set[int] | None = None):
+    def __init__(
+        self, available_devices: set[int] | None = None, max_device_count: int = 8
+    ):
         if available_devices is None:
-            available_devices = set(range(0, 8))
-        assert all(isinstance(x, int) for x in available_devices)
-        assert all(x >= 0 and x < 8 for x in available_devices)
+            available_devices = set(range(0, max_device_count))
+        else:
+            # Validate types first
+            assert all(
+                isinstance(x, int) for x in available_devices
+            ), f"All device IDs must be integers, got: {available_devices}"
+            # When available_devices is provided (e.g., from CUDA_VISIBLE_DEVICES),
+            # adjust max_device_count to accommodate the highest device ID
+            if available_devices:
+                max_device_count = max(max(available_devices) + 1, max_device_count)
+
+        assert all(
+            x >= 0 for x in available_devices
+        ), f"All device IDs must be non-negative, got: {available_devices}"
         self.available_gpus = available_devices
+        self.max_device_count = max_device_count
 
     def get_available_gpus(self) -> list[str]:
         """Returns a list of available GPU devices."""
@@ -171,8 +220,18 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
171214 f"Invalid CUDA_VISIBLE_DEVICES format: '{ cuda_visible_devices } '. "
172215 f"Expected comma-separated integers (e.g., '0,1,2'). Error: { e } "
173216 ) from e
217+
218+ # Get the actual GPU count for the local host
219+ try :
220+ local_gpu_count = torch .cuda .device_count ()
221+ except Exception :
222+ # If torch is not available or CUDA is not available, assume no GPUs
223+ local_gpu_count = 0
224+
174225 self ._host_gpu_map = {
175- self ._this_host_id : GpuManager (available_local_devices ),
226+ self ._this_host_id : GpuManager (
227+ available_local_devices , max_device_count = local_gpu_count
228+ ),
176229 }
177230 self ._proc_host_map = {}
178231 self ._host_mesh_map = {}
@@ -277,7 +336,11 @@ async def get_proc_mesh(
                     num_hosts=num_hosts,
                 )
                 host_id = uuid.uuid1()
-                gpu_manager = GpuManager()
+                # Get the GPU count from the remote host
+                remote_gpu_count = await get_host_gpus(host_mesh)
+                gpu_manager = GpuManager(max_device_count=remote_gpu_count)
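+                # Unlike the local host, no CUDA_VISIBLE_DEVICES filtering applies
+                # here: all of the remote host's devices are assumed available.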
                 self._host_gpu_map[host_id] = gpu_manager
                 host_mesh._host_id = host_id
             else: