1515import sys
1616import threading
1717import time
18+ from concurrent .futures import ThreadPoolExecutor
1819from functools import partial
19- from signal import SIGINT
2020from signal import SIGTERM
2121from subprocess import check_output
2222from subprocess import PIPE
@@ -85,6 +85,12 @@ class UnknownStatus(LauncherError):
8585class BaseLauncher (LoggingConfigurable ):
8686 """An abstraction for starting, stopping and signaling a process."""
8787
88+ stop_timeout = Integer (
89+ 60 ,
90+ config = True ,
91+ help = "The number of seconds to wait for a process to exit before raising a TimeoutError in stop" ,
92+ )
93+
8894 # In all of the launchers, the work_dir is where child processes will be
8995 # run. This will usually be the profile_dir, but may not be. any work_dir
9096 # passed into the __init__ method will override the config value.
@@ -249,6 +255,10 @@ def signal(self, sig):
249255 """
250256 raise NotImplementedError ('signal must be implemented in a subclass' )
251257
258+ def join (self , timeout = None ):
259+ """Wait for the process to finish"""
260+ raise NotImplementedError ('join must be implemented in a subclass' )
261+
252262 output_limit = Integer (
253263 100 ,
254264 config = True ,
@@ -376,6 +386,12 @@ def _default_output_file(self):
376386 os .makedirs (log_dir , exist_ok = True )
377387 return os .path .join (log_dir , f'{ self .identifier } .log' )
378388
389+ stop_seconds_until_kill = Integer (
390+ 5 ,
391+ config = True ,
392+ help = """The number of seconds to wait for a process to exit after sending SIGTERM before sending SIGKILL""" ,
393+ )
394+
379395 stdout = None
380396 stderr = None
381397 process = None
@@ -446,6 +462,18 @@ def start(self):
446462 if self .log .level <= logging .DEBUG :
447463 self ._start_streaming ()
448464
465+ async def join (self , timeout = None ):
466+ """Wait for the process to exit"""
467+ with ThreadPoolExecutor (1 ) as pool :
468+ try :
469+ await asyncio .wrap_future (
470+ pool .submit (partial (self .process .wait , timeout ))
471+ )
472+ except psutil .TimeoutExpired :
473+ raise TimeoutError (
474+ f"Process { self .pid } did not complete in { timeout } seconds."
475+ )
476+
449477 def _stream_file (self , path ):
450478 """Stream one file"""
451479 with open (path , 'r' ) as f :
@@ -460,7 +488,7 @@ def _stream_file(self, path):
460488 time .sleep (0.1 )
461489
462490 def _start_streaming (self ):
463- t = threading .Thread (
491+ self . _stream_thread = t = threading .Thread (
464492 target = partial (self ._stream_file , self .output_file ),
465493 name = f"Stream Output { self .identifier } " ,
466494 daemon = True ,
@@ -483,35 +511,46 @@ def get_output(self, remove=False):
483511
484512 if remove and os .path .isfile (self .output_file ):
485513 self .log .debug (f"Removing { self .output_file } " )
486- os .remove (self .output_file )
514+ try :
515+ os .remove (self .output_file )
516+ except Exception as e :
517+ # don't crash on failure to remove a file,
518+ # e.g. due to another process having it open
519+ self .log .error (f"Failed to remove { self .output_file } : { e } " )
487520
488521 return self ._output
489522
490- def stop (self ):
491- return self .interrupt_then_kill ()
523+ async def stop (self ):
524+ try :
525+ self .signal (SIGTERM )
526+ except Exception as e :
527+ self .log .debug (f"TERM failed: { e !r} " )
528+
529+ try :
530+ await self .join (timeout = self .stop_seconds_until_kill )
531+ except TimeoutError :
532+ self .log .warning (
533+ f"Process { self .pid } did not exit in { self .stop_seconds_until_kill } seconds after TERM"
534+ )
535+ else :
536+ return
537+
538+ try :
539+ self .signal (SIGKILL )
540+ except Exception as e :
541+ self .log .debug (f"KILL failed: { e !r} " )
542+
543+ await self .join (timeout = self .stop_timeout )
492544
493545 def signal (self , sig ):
494546 if self .state == 'running' :
495- if WINDOWS and sig != SIGINT :
547+ if WINDOWS and sig == SIGKILL :
496548 # use Windows tree-kill for better child cleanup
497- cmd = ['taskkill' , '-pid' , str (self .process .pid ), '-t' ]
498- if sig == SIGKILL :
499- cmd .append ("-f" )
549+ cmd = ['taskkill' , '/pid' , str (self .process .pid ), '/t' , '/F' ]
500550 check_output (cmd )
501551 else :
502552 self .process .send_signal (sig )
503553
504- def interrupt_then_kill (self , delay = 2.0 ):
505- """Send TERM, wait a delay and then send KILL."""
506- try :
507- self .signal (SIGTERM )
508- except Exception as e :
509- self .log .debug (f"interrupt failed: { e !r} " )
510- pass
511- self .killer = asyncio .get_event_loop ().call_later (
512- delay , lambda : self .signal (SIGKILL )
513- )
514-
515554 # callbacks, etc:
516555
517556 def handle_stdout (self , fd , events ):
@@ -637,21 +676,18 @@ def find_args(self):
637676 return ['engine set' ]
638677
639678 def signal (self , sig ):
640- dlist = []
641- for el in itervalues (self .launchers ):
642- d = el .signal (sig )
643- dlist .append (d )
644- return dlist
679+ for el in list (self .launchers .values ()):
680+ el .signal (sig )
645681
646- def interrupt_then_kill (self , delay = 1.0 ):
647- dlist = []
648- for el in itervalues (self .launchers ):
649- d = el .interrupt_then_kill ( delay )
650- dlist . append ( d )
651- return dlist
682+ async def stop (self ):
683+ futures = []
684+ for el in list (self .launchers . values () ):
685+ f = el .stop ( )
686+ if inspect . isawaitable ( f ):
687+ futures . append ( asyncio . ensure_future ( f ))
652688
653- def stop ( self ) :
654- return self . interrupt_then_kill ( )
689+ if futures :
690+ await asyncio . gather ( * futures )
655691
656692 def _notice_engine_stopped (self , data ):
657693 identifier = data ['identifier' ]
@@ -1146,6 +1182,12 @@ def wait_one(self, timeout):
11461182 raise TimeoutError ("still running" )
11471183 return int (values .get ("exit_code" , - 1 ))
11481184
1185+ async def join (self , timeout = None ):
1186+ with ThreadPoolExecutor (1 ) as pool :
1187+ await asyncio .wrap_future (
1188+ pool .submit (partial (self .wait_one , timeout = timeout ))
1189+ )
1190+
11491191 def signal (self , sig ):
11501192 if self .state == 'running' :
11511193 check_output (
@@ -1306,7 +1348,7 @@ def start(self, n):
13061348 return dlist
13071349
13081350
1309- class SSHProxyEngineSetLauncher (SSHLauncher ):
1351+ class SSHProxyEngineSetLauncher (SSHLauncher , EngineLauncher ):
13101352 """Launcher for calling
13111353 `ipcluster engines` on a remote machine.
13121354
0 commit comments