@@ -186,6 +186,12 @@ def _req_keepvars_default(self):
186186 "specification." ,
187187 ).tag (config = True )
188188
189+ connect_to_job_cmd = Unicode ('' ,
190+ help = "Command to connect to running batch job and forward the port "
191+ "of the running notebook to the Hub. If empty, direct connectivity is assumed. "
192+ "Uses self.job_id as {job_id} and the self.port as {port}."
193+ ).tag (config = True )
194+
189195 # Raw output of job submission command unless overridden
190196 job_id = Unicode ()
191197
@@ -215,6 +221,18 @@ def cmd_formatted_for_batch(self):
215221 """The command which is substituted inside of the batch script"""
216222 return " " .join ([self .batchspawner_singleuser_cmd ] + self .cmd + self .get_args ())
217223
224+ async def connect_to_job (self ):
225+ """This command ensures the port of the singleuser server is reachable from the
226+ Batchspawner machine. By default, it does nothing, i.e. direct connectivity
227+ is assumed.
228+ """
229+ subvars = self .get_req_subvars ()
230+ subvars ['job_id' ] = self .job_id
231+ subvars ['port' ] = self .port
232+ cmd = ' ' .join ((format_template (self .exec_prefix , ** subvars ),
233+ format_template (self .connect_to_job_cmd , ** subvars )))
234+ await self .run_background_command (cmd )
235+
218236 async def run_command (self , cmd , input = None , env = None ):
219237 proc = await asyncio .create_subprocess_shell (
220238 cmd ,
@@ -268,6 +286,46 @@ async def run_command(self, cmd, input=None, env=None):
268286 out = out .decode ().strip ()
269287 return out
270288
289+ # List of running background processes, e.g. used by connect_to_job.
290+ background_processes = []
291+
292+ async def _async_wait_process (self , sleep_time ):
293+ """Asynchronously sleeping process for delayed checks"""
294+ await asyncio .sleep (sleep_time )
295+
296+ async def run_background_command (self , cmd , startup_check_delay = 1 , input = None , env = None ):
297+ """Runs the given background command, adds it to background_processes,
298+ and checks if the command is still running after startup_check_delay."""
299+ background_process = self .run_command (cmd , input , env )
300+ success_check_delay = self ._async_wait_process (startup_check_delay )
301+
302+ # Start up both the success check process and the actual process.
303+ done , pending = await asyncio .wait ([background_process , success_check_delay ], return_when = asyncio .FIRST_COMPLETED )
304+
305+ # If the success check process is the one which exited first, all is good, else fail.
306+ if list (done )[0 ]._coro == success_check_delay :
307+ background_task = list (pending )[0 ]
308+ self .background_processes .append (background_task )
309+ return background_task
310+ else :
311+ self .log .error ("Background command exited early: %s" % cmd )
312+ gather_pending = asyncio .gather (* pending )
313+ gather_pending .cancel ()
314+ try :
315+ self .log .debug ("Cancelling pending success check task..." )
316+ await gather_pending
317+ except asyncio .CancelledError :
318+ self .log .debug ("Cancel was successful." )
319+ pass
320+
321+ # Retrieve exception from "done" process.
322+ try :
323+ gather_done = asyncio .gather (* done )
324+ await gather_done
325+ except :
326+ self .log .debug ("Retrieving exception from failed background task..." )
327+ raise RuntimeError ('{} failed!' .format (cmd ))
328+
271329 async def _get_batch_script (self , ** subvars ):
272330 """Format batch script from vars"""
273331 # Could be overridden by subclasses, but mainly useful for testing
@@ -299,6 +357,27 @@ async def submit_batch_script(self):
299357 self .job_id = ""
300358 return self .job_id
301359
360+ def background_tasks_ok (self ):
361+ # Check background processes.
362+ if self .background_processes :
363+ self .log .debug ('Checking background processes...' )
364+ for background_process in self .background_processes :
365+ if background_process .done ():
366+ self .log .debug ('Found a background process in state "done"...' )
367+ try :
368+ background_exception = background_process .exception ()
369+ except asyncio .CancelledError :
370+ self .log .error ('Background process was cancelled!' )
371+ if background_exception :
372+ self .log .error ('Background process exited with an exception:' )
373+ self .log .error (background_exception )
374+ self .log .error ('At least one background process exited!' )
375+ return False
376+ else :
377+ self .log .debug ('Found a not-yet-done background process...' )
378+ self .log .debug ('All background processes still running.' )
379+ return True
380+
302381 # Override if your batch system needs something more elaborate to query the job status
303382 batch_query_cmd = Unicode (
304383 "" ,
@@ -353,6 +432,29 @@ async def cancel_batch_job(self):
353432 )
354433 )
355434 self .log .info ("Cancelling job " + self .job_id + ": " + cmd )
435+
436+ if self .background_processes :
437+ self .log .debug ('Job being cancelled, cancelling background processes...' )
438+ for background_process in self .background_processes :
439+ if not background_process .cancelled ():
440+ try :
441+ background_process .cancel ()
442+ except :
443+ self .log .error ('Encountered an exception cancelling background process...' )
444+ self .log .debug ('Cancelled background process, waiting for it to finish...' )
445+ try :
446+ await asyncio .wait ([background_process ])
447+ except asyncio .CancelledError :
448+ self .log .error ('Successfully cancelled background process.' )
449+ pass
450+ except :
451+ self .log .error ('Background process exited with another exception!' )
452+ raise
453+ else :
454+ self .log .debug ('Background process already cancelled...' )
455+ self .background_processes .clear ()
456+ self .log .debug ('All background processes cancelled.' )
457+
356458 await self .run_command (cmd )
357459
358460 def load_state (self , state ):
@@ -400,6 +502,13 @@ async def poll(self):
400502 """Poll the process"""
401503 status = await self .query_job_status ()
402504 if status in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
505+ if not self .background_tasks_ok ():
506+ self .log .debug ('Going to stop job, since background tasks have failed!' )
507+ await self .stop (now = True )
508+ status = await self .query_job_status ()
509+ if status not in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
510+ self .clear_state ()
511+ return 1
403512 return None
404513 else :
405514 self .clear_state ()
@@ -466,6 +575,9 @@ async def start(self):
466575 )
467576 )
468577
578+ if self .connect_to_job_cmd :
579+ await self .connect_to_job ()
580+
469581 return self .ip , self .port
470582
471583 async def stop (self , now = False ):
0 commit comments