@@ -172,6 +172,12 @@ def _req_keepvars_default(self):
172172 "specification."
173173 ).tag (config = True )
174174
175+ connect_to_job_cmd = Unicode ('' ,
176+ help = "Command to connect to running batch job and forward the port "
177+ "of the running notebook to the Hub. If empty, direct connectivity is assumed. "
178+ "Uses self.job_id as {job_id} and the self.port as {port}."
179+ ).tag (config = True )
180+
175181 # Raw output of job submission command unless overridden
176182 job_id = Unicode ()
177183
@@ -200,6 +206,18 @@ def cmd_formatted_for_batch(self):
200206 """The command which is substituted inside of the batch script"""
201207 return ' ' .join ([self .batchspawner_singleuser_cmd ] + self .cmd + self .get_args ())
202208
209+ async def connect_to_job (self ):
210+ """This command ensures the port of the singleuser server is reachable from the
211+ Batchspawner machine. By default, it does nothing, i.e. direct connectivity
212+ is assumed.
213+ """
214+ subvars = self .get_req_subvars ()
215+ subvars ['job_id' ] = self .job_id
216+ subvars ['port' ] = self .port
217+ cmd = ' ' .join ((format_template (self .exec_prefix , ** subvars ),
218+ format_template (self .connect_to_job_cmd , ** subvars )))
219+ await self .run_background_command (cmd )
220+
203221 async def run_command (self , cmd , input = None , env = None ):
204222 proc = await asyncio .create_subprocess_shell (cmd , env = env ,
205223 stdin = asyncio .subprocess .PIPE ,
@@ -236,6 +254,46 @@ async def run_command(self, cmd, input=None, env=None):
236254 out = out .decode ().strip ()
237255 return out
238256
257+ # List of running background processes, e.g. used by connect_to_job.
258+ background_processes = []
259+
260+ async def _async_wait_process (self , sleep_time ):
261+ """Asynchronously sleeping process for delayed checks"""
262+ await asyncio .sleep (sleep_time )
263+
264+ async def run_background_command (self , cmd , startup_check_delay = 1 , input = None , env = None ):
265+ """Runs the given background command, adds it to background_processes,
266+ and checks if the command is still running after startup_check_delay."""
267+ background_process = self .run_command (cmd , input , env )
268+ success_check_delay = self ._async_wait_process (startup_check_delay )
269+
270+ # Start up both the success check process and the actual process.
271+ done , pending = await asyncio .wait ([background_process , success_check_delay ], return_when = asyncio .FIRST_COMPLETED )
272+
273+ # If the success check process is the one which exited first, all is good, else fail.
274+ if list (done )[0 ]._coro == success_check_delay :
275+ background_task = list (pending )[0 ]
276+ self .background_processes .append (background_task )
277+ return background_task
278+ else :
279+ self .log .error ("Background command exited early: %s" % cmd )
280+ gather_pending = asyncio .gather (* pending )
281+ gather_pending .cancel ()
282+ try :
283+ self .log .debug ("Cancelling pending success check task..." )
284+ await gather_pending
285+ except asyncio .CancelledError :
286+ self .log .debug ("Cancel was successful." )
287+ pass
288+
289+ # Retrieve exception from "done" process.
290+ try :
291+ gather_done = asyncio .gather (* done )
292+ await gather_done
293+ except :
294+ self .log .debug ("Retrieving exception from failed background task..." )
295+ raise RuntimeError ('{} failed!' .format (cmd ))
296+
239297 async def _get_batch_script (self , ** subvars ):
240298 """Format batch script from vars"""
241299 # Could be overridden by subclasses, but mainly useful for testing
@@ -263,6 +321,27 @@ async def submit_batch_script(self):
263321 self .job_id = ''
264322 return self .job_id
265323
324+ def background_tasks_ok (self ):
325+ # Check background processes.
326+ if self .background_processes :
327+ self .log .debug ('Checking background processes...' )
328+ for background_process in self .background_processes :
329+ if background_process .done ():
330+ self .log .debug ('Found a background process in state "done"...' )
331+ try :
332+ background_exception = background_process .exception ()
333+ except asyncio .CancelledError :
334+ self .log .error ('Background process was cancelled!' )
335+ if background_exception :
336+ self .log .error ('Background process exited with an exception:' )
337+ self .log .error (background_exception )
338+ self .log .error ('At least one background process exited!' )
339+ return False
340+ else :
341+ self .log .debug ('Found a not-yet-done background process...' )
342+ self .log .debug ('All background processes still running.' )
343+ return True
344+
266345 # Override if your batch system needs something more elaborate to query the job status
267346 batch_query_cmd = Unicode ('' ,
268347 help = "Command to run to query job status. Formatted using req_xyz traits as {xyz} "
@@ -307,6 +386,29 @@ async def cancel_batch_job(self):
307386 cmd = ' ' .join ((format_template (self .exec_prefix , ** subvars ),
308387 format_template (self .batch_cancel_cmd , ** subvars )))
309388 self .log .info ('Cancelling job ' + self .job_id + ': ' + cmd )
389+
390+ if self .background_processes :
391+ self .log .debug ('Job being cancelled, cancelling background processes...' )
392+ for background_process in self .background_processes :
393+ if not background_process .cancelled ():
394+ try :
395+ background_process .cancel ()
396+ except :
397+ self .log .error ('Encountered an exception cancelling background process...' )
398+ self .log .debug ('Cancelled background process, waiting for it to finish...' )
399+ try :
400+ await asyncio .wait ([background_process ])
401+ except asyncio .CancelledError :
402+ self .log .error ('Successfully cancelled background process.' )
403+ pass
404+ except :
405+ self .log .error ('Background process exited with another exception!' )
406+ raise
407+ else :
408+ self .log .debug ('Background process already cancelled...' )
409+ self .background_processes .clear ()
410+ self .log .debug ('All background processes cancelled.' )
411+
310412 await self .run_command (cmd )
311413
312414 def load_state (self , state ):
@@ -354,6 +456,13 @@ async def poll(self):
354456 """Poll the process"""
355457 status = await self .query_job_status ()
356458 if status in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
459+ if not self .background_tasks_ok ():
460+ self .log .debug ('Going to stop job, since background tasks have failed!' )
461+ await self .stop (now = True )
462+ status = await self .query_job_status ()
463+ if status not in (JobStatus .PENDING , JobStatus .RUNNING , JobStatus .UNKNOWN ):
464+ self .clear_state ()
465+ return 1
357466 return None
358467 else :
359468 self .clear_state ()
@@ -413,6 +522,9 @@ async def start(self):
413522 self .job_id , self .ip , self .port )
414523 )
415524
525+ if self .connect_to_job_cmd :
526+ await self .connect_to_job ()
527+
416528 return self .ip , self .port
417529
418530 async def stop (self , now = False ):
0 commit comments