 import os
 import subprocess
 import sys
-from typing import Any, cast, Dict, FrozenSet, List, Optional, Sequence
+from typing import Any, Dict, FrozenSet, List, Optional, Sequence
 
 from monarch._rust_bindings.monarch_hyperactor.channel import ChannelTransport
 from monarch._rust_bindings.monarch_hyperactor.config import configure
 
 from monarch._src.actor.bootstrap import attach_to_workers
-from monarch._src.actor.host_mesh import HostMesh
 from monarch._src.job.job import JobState, JobTrait
 
 
@@ -55,6 +54,8 @@ def __init__(
         log_dir: Optional[str] = None,
         exclusive: bool = True,
         gpus_per_node: Optional[int] = None,
+        cpus_per_task: Optional[int] = None,
+        mem: Optional[str] = None,
     ) -> None:
         """
         Args:
@@ -84,6 +85,8 @@ def __init__(
         self._log_dir: str = log_dir if log_dir is not None else os.getcwd()
         self._exclusive = exclusive
         self._gpus_per_node = gpus_per_node
+        self._cpus_per_task = cpus_per_task
+        self._mem = mem
         # Track the single SLURM job ID and all allocated hostnames
         self._slurm_job_id: Optional[str] = None
         self._all_hostnames: List[str] = []
@@ -128,12 +131,33 @@ def _submit_slurm_job(self, num_nodes: int) -> str:
         if self._gpus_per_node is not None:
             sbatch_directives.append(f"#SBATCH --gpus-per-node={self._gpus_per_node}")
 
+        if self._cpus_per_task is not None:
+            sbatch_directives.append(f"#SBATCH --cpus-per-task={self._cpus_per_task}")
+
+        if self._mem is not None:
+            sbatch_directives.append(f"#SBATCH --mem={self._mem}")
+
         if self._exclusive:
             sbatch_directives.append("#SBATCH --exclusive")
 
-        if self._partition:
+        if self._partition is not None:
             sbatch_directives.append(f"#SBATCH --partition={self._partition}")
 
+        if (
+            not self._exclusive
+            and self._partition is not None
+            and self._gpus_per_node is not None
+        ):
+            gpus_per_task = self._gpus_per_node // self._ntasks_per_node
+            assert (
+                self._partition
+            ), "Slurm partition must be set for jobs that share nodes with other jobs"
+            self.share_node(
+                tasks_per_node=self._ntasks_per_node,
+                gpus_per_task=gpus_per_task,
+                partition=self._partition,
+            )
+
         # Add any additional slurm args as directives
         for arg in self._slurm_args:
             if arg.startswith("-"):
@@ -297,6 +321,8 @@ def can_run(self, spec: "JobTrait") -> bool:
             and spec._time_limit == self._time_limit
             and spec._partition == self._partition
             and spec._gpus_per_node == self._gpus_per_node
+            and spec._cpus_per_task == self._cpus_per_task
+            and spec._mem == self._mem
             and self._jobs_active()
         )
 
@@ -318,6 +344,28 @@ def _jobs_active(self) -> bool:
 
         return True
 
+    def share_node(
+        self, tasks_per_node: int, gpus_per_task: int, partition: str
+    ) -> None:
+        """
+        Share a node with other jobs.
+        """
+        try:
+            import clusterscope
+        except ImportError:
+            raise RuntimeError(
+                "please install clusterscope to use share_node. `pip install clusterscope`"
+            )
+        self._exclusive = False
+
+        slurm_args = clusterscope.job_gen_task_slurm(
+            partition=partition,
+            gpus_per_task=gpus_per_task,
+            tasks_per_node=tasks_per_node,
+        )
+        self._cpus_per_task = slurm_args["cpus_per_task"]
+        self._mem = slurm_args["memory"]
+
     def _kill(self) -> None:
         """Cancel the SLURM job."""
         if self._slurm_job_id is not None:
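For reference, a minimal standalone sketch of how the new `cpus_per_task` and `mem` options map onto sbatch directives. The helper name `build_sbatch_directives` and the example values are illustrative only and not part of the module; the real `_submit_slurm_job` additionally handles the non-exclusive `share_node` path shown in the diff.

from typing import List, Optional

def build_sbatch_directives(
    gpus_per_node: Optional[int],
    cpus_per_task: Optional[int],
    mem: Optional[str],
    exclusive: bool,
    partition: Optional[str],
) -> List[str]:
    # Mirrors the directive-building portion of _submit_slurm_job above
    # (omitting the share_node logic).
    directives: List[str] = []
    if gpus_per_node is not None:
        directives.append(f"#SBATCH --gpus-per-node={gpus_per_node}")
    if cpus_per_task is not None:
        directives.append(f"#SBATCH --cpus-per-task={cpus_per_task}")
    if mem is not None:
        directives.append(f"#SBATCH --mem={mem}")
    if exclusive:
        directives.append("#SBATCH --exclusive")
    if partition is not None:
        directives.append(f"#SBATCH --partition={partition}")
    return directives

print(build_sbatch_directives(8, 12, "96G", False, "gpu"))
# ['#SBATCH --gpus-per-node=8', '#SBATCH --cpus-per-task=12',
#  '#SBATCH --mem=96G', '#SBATCH --partition=gpu']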