Skip to content

Commit 29c6584

Browse files
allenwang28 (Allen Wang) and a co-author authored
Removes the legacy GPU manager, updates the GPU manager to (#534)
Co-authored-by: Allen Wang <allencwang@fb.com>
1 parent 557a00c commit 29c6584

File tree

5 files changed

+112
-363
lines changed

5 files changed

+112
-363
lines changed

src/forge/controller/provisioner.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import socket
1313
import uuid
1414

15+
import torch
16+
1517
from monarch._src.actor.actor_mesh import ActorMesh
1618
from monarch._src.actor.shape import Extent
1719

@@ -41,8 +43,19 @@ class _RemoteInfoFetcher(Actor):
4143

4244
@endpoint
4345
def get_info(self) -> tuple[str, str]:
46+
"""Returns hostname and port."""
4447
return socket.gethostname(), _get_port()
4548

49+
@endpoint
50+
def get_gpu_count(self) -> int:
51+
"""Returns the number of GPUs available on this host."""
52+
try:
53+
gpu_count = torch.cuda.device_count()
54+
except Exception:
55+
# If torch is not available or CUDA is not available, assume no GPUs
56+
gpu_count = 0
57+
return gpu_count
58+
4659

4760
class EnvSetter(Actor):
4861
"""Actor to set environment variables on each proc in a mesh.
@@ -87,14 +100,26 @@ async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
87100
singleton_slice = {k: slice(0, 1) for k in fetcher.extent.keys()}
88101
fetcher = fetcher.slice(**singleton_slice)
89102
# Fetcher should be a singleton at this point - call_one() will fail otherwise
90-
91103
host, port = await fetcher.get_info.call_one()
92104

93105
# Stopping this proc is the right thing to do, but Monarch does not yet handle manual stops well.
94106
# await throwaway_procs.stop()
95107
return host, port
96108

97109

async def get_host_gpus(host_mesh: HostMesh) -> int:
    """Count the GPUs available on the given host mesh.

    Spawns a single throwaway proc on the mesh, asks it for its CUDA
    device count, and returns that number.
    """
    probe_procs = host_mesh.spawn_procs(per_host={"procs": 1})
    counter = probe_procs.spawn("_gpu_counter", _RemoteInfoFetcher)

    # Slice every mesh dimension down to one entry so that call_one()
    # sees a singleton actor.
    singleton = {dim: slice(0, 1) for dim in counter.extent.keys()}
    counter = counter.slice(**singleton)

    return await counter.get_gpu_count.call_one()
122+
98123
async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
99124
"""Set environment variables on a proc mesh using EnvSetter actor.
100125
@@ -112,17 +137,35 @@ async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
112137
class GpuManager:
113138
"""Tracks and assigns GPU devices on a host.
114139
115-
This currently mimics the `gpu_manager` in system_controllers - we will
116-
consolidate as part of the "proper HostMesh integration" work.
140+
Args:
141+
available_devices: Set of GPU device IDs to manage. If None, uses all devices from 0 to max_device_count-1.
142+
max_device_count: Maximum number of GPU devices on this host. Defaults to 8.
117143
118144
"""
119145

120-
def __init__(self, available_devices: set[int] | None = None):
146+
def __init__(
147+
self, available_devices: set[int] | None = None, max_device_count: int = 8
148+
):
121149
if available_devices is None:
122-
available_devices = set(range(0, 8))
123-
assert all(isinstance(x, int) for x in available_devices)
124-
assert all(x >= 0 and x < 8 for x in available_devices)
150+
available_devices = set(range(0, max_device_count))
151+
else:
152+
# Validate types first
153+
assert all(
154+
isinstance(x, int) for x in available_devices
155+
), f"All device IDs must be integers, got: {available_devices}"
156+
# When available_devices is provided (e.g., from CUDA_VISIBLE_DEVICES),
157+
# adjust max_device_count to accommodate the highest device ID
158+
if available_devices:
159+
max_device_count = max(max(available_devices) + 1, max_device_count)
160+
161+
assert all(
162+
isinstance(x, int) for x in available_devices
163+
), f"All device IDs must be integers, got: {available_devices}"
164+
assert all(
165+
x >= 0 for x in available_devices
166+
), f"All device IDs must be non-negative, got: {available_devices}"
125167
self.available_gpus = available_devices
168+
self.max_device_count = max_device_count
126169

127170
def get_available_gpus(self) -> list[str]:
128171
"""Returns a list of available GPU devices."""
@@ -171,8 +214,18 @@ def __init__(self, cfg: ProvisionerConfig | None = None):
171214
f"Invalid CUDA_VISIBLE_DEVICES format: '{cuda_visible_devices}'. "
172215
f"Expected comma-separated integers (e.g., '0,1,2'). Error: {e}"
173216
) from e
217+
218+
# Get the actual GPU count for the local host
219+
try:
220+
local_gpu_count = torch.cuda.device_count()
221+
except Exception:
222+
# If torch is not available or CUDA is not available, assume no GPUs
223+
local_gpu_count = 0
224+
174225
self._host_gpu_map = {
175-
self._this_host_id: GpuManager(available_local_devices),
226+
self._this_host_id: GpuManager(
227+
available_local_devices, max_device_count=local_gpu_count
228+
),
176229
}
177230
self._proc_host_map = {}
178231
self._host_mesh_map = {}
@@ -277,7 +330,9 @@ async def get_proc_mesh(
277330
num_hosts=num_hosts,
278331
)
279332
host_id = uuid.uuid1()
280-
gpu_manager = GpuManager()
333+
# Get the GPU count from the remote host
334+
remote_gpu_count = await get_host_gpus(host_mesh)
335+
gpu_manager = GpuManager(max_device_count=remote_gpu_count)
281336
self._host_gpu_map[host_id] = gpu_manager
282337
host_mesh._host_id = host_id
283338
else:

src/forge/controller/system_controllers/__init__.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/forge/controller/system_controllers/gpu_manager.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

0 commit comments

Comments
 (0)