@@ -61,6 +61,7 @@ def __init__(
6161 mounts : Optional [List [str ]] = None ,
6262 rdzv_port : int = 29500 ,
6363 scheduler_args : Optional [Dict [str , str ]] = None ,
64+ image : Optional [str ] = None ,
6465 ):
6566 if bool (script ) == bool (m ): # logical XOR
6667 raise ValueError (
@@ -82,6 +83,7 @@ def __init__(
8283 self .scheduler_args : Dict [str , str ] = (
8384 scheduler_args if scheduler_args is not None else dict ()
8485 )
86+ self .image = image
8587
8688 def _dry_run (self , cluster : "Cluster" ):
8789 j = f"{ cluster .config .max_worker } x{ max (cluster .config .gpu , 1 )} " # # of proc. = # of gpus
@@ -108,15 +110,58 @@ def _dry_run(self, cluster: "Cluster"):
108110 workspace = f"file://{ Path .cwd ()} " ,
109111 )
110112
111- def submit (self , cluster : "Cluster" ) -> "Job" :
113+ def _missing_spec (self , spec : str ):
114+ raise ValueError (f"Job definition missing arg: { spec } " )
115+
def _dry_run_no_cluster(self):
    """Build a dry-run for this job definition without a ``Cluster``.

    All sizing information (``cpu``, ``gpu``, ``memMB``, ``j``) and the
    container ``image`` are read from this object's own attributes. Any of
    those, or ``name``, being ``None`` is reported through
    ``self._missing_spec``.

    :return: whatever ``torchx_runner.dryrun`` returns for the
        ``kubernetes_mcad`` scheduler with an empty workspace.
    """

    def _required(value, label):
        # Hand back the value when it is set; otherwise delegate to
        # _missing_spec with the human-readable label.
        if value is not None:
            return value
        return self._missing_spec(label)

    app_def = ddp(
        *self.script_args,
        script=self.script,
        m=self.m,
        name=_required(self.name, "name"),
        h=self.h,
        cpu=_required(self.cpu, "cpu (# cpus per worker)"),
        gpu=_required(self.gpu, "gpu (# gpus per worker)"),
        memMB=_required(self.memMB, "memMB (memory in MB)"),
        # "# of proc. = # of gpus" per the original author's note.
        j=_required(self.j, "j (`workers`x`procs`)"),
        env=self.env,  # NOTE(review): original author asked "should this still exist?"
        max_retries=self.max_retries,
        rdzv_port=self.rdzv_port,  # NOTE(review): original author asked "should this still exist?"
        mounts=self.mounts,
        image=_required(self.image, "image"),
    )
    return torchx_runner.dryrun(
        app=app_def,
        scheduler="kubernetes_mcad",
        cfg=self.scheduler_args,
        workspace="",
    )
150+
def submit(self, cluster: Optional["Cluster"] = None) -> "Job":
    """Create a ``DDPJob`` from this definition and return it.

    :param cluster: cluster to hand to the job. Defaults to ``None``, in
        which case the ``DDPJob`` is constructed without a cluster.
    :return: the newly constructed ``DDPJob``.
    """
    # The annotation is Optional because None is the declared default.
    return DDPJob(self, cluster)
113153
114154
115155class DDPJob (Job ):
def __init__(self, job_definition: "DDPJobDefinition", cluster: Optional["Cluster"]):
    """Schedule the job through ``torchx_runner`` and record it in ``all_jobs``.

    :param job_definition: definition used to produce the dry-run that is
        passed to ``torchx_runner.schedule``.
    :param cluster: cluster to build the dry-run against, or ``None`` to
        use the definition's cluster-less dry-run instead.
    """
    self.job_definition = job_definition
    self.cluster = cluster
    # Test for None explicitly: "no cluster supplied" is the only case that
    # should take the cluster-less path, regardless of how a Cluster object
    # happens to evaluate in a boolean context.
    if cluster is not None:
        dry_run_info = job_definition._dry_run(cluster)
    else:
        dry_run_info = job_definition._dry_run_no_cluster()
    self._app_handle = torchx_runner.schedule(dry_run_info)
    all_jobs.append(self)
121166
122167 def status (self ) -> str :
0 commit comments