Skip to content

Commit b97900e

Browse files
committed
BatchSpawnerBase: Add background_tasks, connect_to_job feature.
This adds the option to start a "connect_to_job" background task on the hub when a job starts, which establishes connectivity to the actual single-user server. An example is "condor_ssh_to_job" for HTCondor batch systems. Additionally, the background tasks are monitored: - at startup, where the background task is given some time to successfully establish connectivity; - in poll() during job runtime, where the job is terminated if a background task has failed.
1 parent d4d6593 commit b97900e

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed

batchspawner/batchspawner.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,12 @@ def _req_keepvars_default(self):
186186
"specification.",
187187
).tag(config=True)
188188

189+
# Configurable command template; the help text below describes its
# {job_id} and {port} placeholders and the meaning of an empty value.
connect_to_job_cmd = Unicode(
    '',
    help=(
        "Command to connect to running batch job and forward the port "
        "of the running notebook to the Hub. If empty, direct connectivity is assumed. "
        "Uses self.job_id as {job_id} and the self.port as {port}."
    ),
).tag(config=True)
194+
189195
# Raw output of job submission command unless overridden
190196
job_id = Unicode()
191197

@@ -215,6 +221,18 @@ def cmd_formatted_for_batch(self):
215221
"""The command which is substituted inside of the batch script"""
216222
return " ".join([self.batchspawner_singleuser_cmd] + self.cmd + self.get_args())
217223

224+
async def connect_to_job(self):
    """Launch the configured connect-to-job command as a background command.

    The template variables come from ``self.get_req_subvars()`` with
    ``job_id`` and ``port`` set from ``self.job_id`` and ``self.port``.
    ``self.exec_prefix`` and ``self.connect_to_job_cmd`` are both formatted
    with those variables, joined with a single space, and handed to
    ``self.run_background_command``.
    """
    template_vars = self.get_req_subvars()
    template_vars['job_id'] = self.job_id
    template_vars['port'] = self.port

    prefix_part = format_template(self.exec_prefix, **template_vars)
    connect_part = format_template(self.connect_to_job_cmd, **template_vars)
    full_cmd = ' '.join((prefix_part, connect_part))

    await self.run_background_command(full_cmd)
235+
218236
async def run_command(self, cmd, input=None, env=None):
219237
proc = await asyncio.create_subprocess_shell(
220238
cmd,
@@ -268,6 +286,46 @@ async def run_command(self, cmd, input=None, env=None):
268286
out = out.decode().strip()
269287
return out
270288

289+
# List of running background processes, e.g. used by connect_to_job.
290+
# Tasks appended by run_background_command once a command survived its
# startup check.
# NOTE(review): this assignment sits outside any method, so it looks like a
# class-level mutable list shared by every instance -- confirm whether
# per-instance state was intended.
background_processes = []
291+
292+
async def _async_wait_process(self, sleep_time):
    """Suspend the calling coroutine for ``sleep_time`` seconds.

    Used as the timer that a freshly started background command is raced
    against in ``run_background_command``.
    """
    pause = asyncio.sleep(sleep_time)
    await pause
295+
296+
async def run_background_command(self, cmd, startup_check_delay=1, input=None, env=None):
    """Start ``cmd`` in the background and verify that it survives startup.

    The command is launched through ``self.run_command`` and raced against a
    timer of ``startup_check_delay`` seconds. If the timer finishes first,
    the command is considered to have started successfully and its task is
    recorded in ``self.background_processes``.

    Args:
        cmd: Command line handed to ``self.run_command``.
        startup_check_delay: Seconds the command must keep running to count
            as started.
        input: Forwarded unchanged to ``self.run_command``.
        env: Forwarded unchanged to ``self.run_command``.

    Returns:
        The still-running background task when the timer elapsed first, or
        ``None`` when the command finished without an error before the
        timer elapsed.

    Raises:
        RuntimeError: If the command failed or was cancelled before the
            timer elapsed.
    """
    # ``background_processes`` is declared as a bare mutable list in this
    # file; give every instance its own list before appending so tasks of
    # different instances are never mixed up.
    if "background_processes" not in vars(self):
        self.background_processes = []

    # asyncio.wait() only accepts tasks/futures on current Python versions
    # (bare coroutines were removed in 3.11), and holding the task objects
    # lets us identify the winner without private Task attributes.
    background_task = asyncio.ensure_future(self.run_command(cmd, input, env))
    success_check_task = asyncio.ensure_future(
        self._async_wait_process(startup_check_delay)
    )

    # Run the command and the startup timer concurrently.
    done, pending = await asyncio.wait(
        [background_task, success_check_task],
        return_when=asyncio.FIRST_COMPLETED,
    )

    # The timer finishing means the command outlived the startup window.
    if success_check_task in done:
        self.background_processes.append(background_task)
        return background_task

    self.log.error("Background command exited early: %s", cmd)

    # The timer is no longer needed; cancel it and wait until it is gone.
    self.log.debug("Cancelling pending success check task...")
    for leftover in pending:
        leftover.cancel()
    await asyncio.gather(*pending, return_exceptions=True)
    self.log.debug("Cancel was successful.")

    # Surface a failure of the command, keeping the original cause attached.
    was_cancelled = background_task.cancelled()
    failure = None if was_cancelled else background_task.exception()
    if was_cancelled or failure is not None:
        self.log.debug("Retrieving exception from failed background task...")
        raise RuntimeError('{} failed!'.format(cmd)) from failure
    return None
328+
271329
async def _get_batch_script(self, **subvars):
272330
"""Format batch script from vars"""
273331
# Could be overridden by subclasses, but mainly useful for testing
@@ -299,6 +357,27 @@ async def submit_batch_script(self):
299357
self.job_id = ""
300358
return self.job_id
301359

360+
def background_tasks_ok(self):
    """Report whether every tracked background task is still running.

    Returns:
        ``True`` when ``self.background_processes`` is empty or none of its
        tasks is done; ``False`` as soon as one task is found in the "done"
        state, whether it finished, failed, or was cancelled.
    """
    if not self.background_processes:
        return True

    self.log.debug('Checking background processes...')
    for background_process in self.background_processes:
        if not background_process.done():
            self.log.debug('Found a not-yet-done background process...')
            continue

        self.log.debug('Found a background process in state "done"...')
        try:
            background_exception = background_process.exception()
        except asyncio.CancelledError:
            # exception() raises for a cancelled task, so there is no
            # exception object to report; bind the name so the check below
            # cannot fail with UnboundLocalError.
            background_exception = None
            self.log.error('Background process was cancelled!')
        if background_exception is not None:
            self.log.error('Background process exited with an exception:')
            self.log.error(background_exception)
        self.log.error('At least one background process exited!')
        return False

    self.log.debug('All background processes still running.')
    return True
380+
302381
# Override if your batch system needs something more elaborate to query the job status
303382
batch_query_cmd = Unicode(
304383
"",
@@ -353,6 +432,29 @@ async def cancel_batch_job(self):
353432
)
354433
)
355434
self.log.info("Cancelling job " + self.job_id + ": " + cmd)
435+
436+
if self.background_processes:
437+
self.log.debug('Job being cancelled, cancelling background processes...')
438+
for background_process in self.background_processes:
439+
if not background_process.cancelled():
440+
try:
441+
background_process.cancel()
442+
except:
443+
self.log.error('Encountered an exception cancelling background process...')
444+
self.log.debug('Cancelled background process, waiting for it to finish...')
445+
try:
446+
await asyncio.wait([background_process])
447+
except asyncio.CancelledError:
448+
self.log.error('Successfully cancelled background process.')
449+
pass
450+
except:
451+
self.log.error('Background process exited with another exception!')
452+
raise
453+
else:
454+
self.log.debug('Background process already cancelled...')
455+
self.background_processes.clear()
456+
self.log.debug('All background processes cancelled.')
457+
356458
await self.run_command(cmd)
357459

358460
def load_state(self, state):
@@ -400,6 +502,13 @@ async def poll(self):
400502
"""Poll the process"""
401503
status = await self.query_job_status()
402504
if status in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
505+
if not self.background_tasks_ok():
506+
self.log.debug('Going to stop job, since background tasks have failed!')
507+
await self.stop(now=True)
508+
status = await self.query_job_status()
509+
if status not in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
510+
self.clear_state()
511+
return 1
403512
return None
404513
else:
405514
self.clear_state()
@@ -466,6 +575,9 @@ async def start(self):
466575
)
467576
)
468577

578+
if self.connect_to_job_cmd:
579+
await self.connect_to_job()
580+
469581
return self.ip, self.port
470582

471583
async def stop(self, now=False):

0 commit comments

Comments
 (0)