1- import numpy
1+ import sys
2+ import signal
23import threading
4+
5+ import numpy
36from . import exceptions
47
58from functools import wraps
1013)
1114
1215
13- ThreadStatus = Literal ['Idle' , 'Running' , 'Invoking hooks' , 'Completed' , 'Errored' ]
16+ ThreadStatus = Literal ['Idle' , 'Running' , 'Invoking hooks' , 'Completed' , 'Errored' , 'Killed' ]
1417Data_In = Any
1518Data_Out = Any
1619Overflow_In = Any
@@ -34,6 +37,7 @@ class Thread(threading.Thread):
3437
3538 # threading.Thread stuff
3639 _initialized : bool
40+ _run : Callable
3741
3842
3943 def __init__ (
@@ -108,6 +112,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
108112
109113
110114 def _invoke_hooks (self ) -> None :
115+ """Invokes hooks in the thread"""
111116 errors : List [Tuple [Exception , str ]] = []
112117 for hook in self .hooks :
113118 try :
@@ -126,14 +131,33 @@ def _invoke_hooks(self) -> None:
126131
127132
128133 def _handle_exceptions (self ) -> None :
129- """Raises exceptions if not suppressed"""
134+ """Raises exceptions if not suppressed in the main thread """
130135 if self .suppress_errors :
131136 return
132137
133138 for e in self .errors :
134139 raise e
135140
136141
142+ def global_trace (self , frame , event : str , arg ) -> Callable | None :
143+ if event == 'call' :
144+ return self .local_trace
145+
146+ def local_trace (self , frame , event , arg ):
147+ if self .status == 'Killed' and event == 'line' :
148+ print ('KILLED ident:%s' % self .ident )
149+ raise SystemExit ()
150+ return self .local_trace
151+
152+ def _run_with_trace (self ) -> None :
153+ """This will replace `threading.Thread`'s `run()` method"""
154+ if not self ._run :
155+ raise exceptions .ThreadNotInitializedError ('Running `_run_with_trace` may cause unintended behaviour, run `start` instead' )
156+
157+ sys .settrace (self .global_trace )
158+ self ._run ()
159+
160+
137161 @property
138162 def result (self ) -> Data_Out :
139163 """
@@ -147,7 +171,7 @@ def result(self) -> Data_Out:
147171 """
148172 if not self ._initialized :
149173 raise exceptions .ThreadNotInitializedError ()
150- if self .status == 'Idle' :
174+ if self .status in [ 'Idle' , 'Killed' ] :
151175 raise exceptions .ThreadNotRunningError ()
152176
153177 self ._handle_exceptions ()
@@ -204,7 +228,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
204228 if not self ._initialized :
205229 raise exceptions .ThreadNotInitializedError ()
206230
207- if self .status == 'Idle' :
231+ if self .status == [ 'Idle' , 'Killed' ] :
208232 raise exceptions .ThreadNotRunningError ()
209233
210234 super ().join (timeout )
@@ -224,6 +248,20 @@ def get_return_value(self) -> Data_Out:
224248 return self .result
225249
226250
251+ def kill (self ) -> None :
252+ """
253+ Kills the thread
254+
255+ Raises
256+ ------
257+ ThreadNotInitializedError: If the thread is not initialized
258+ ThreadNotRunningError: If the thread is not running
259+ """
260+ if not self .is_alive ():
261+ raise exceptions .ThreadNotRunningError ()
262+ self .status = 'Killed'
263+
264+
227265 def start (self ) -> None :
228266 """
229267 Starts the thread
@@ -236,6 +274,8 @@ def start(self) -> None:
236274 if self .is_alive ():
237275 raise exceptions .ThreadStillRunningError ()
238276
277+ self ._run = self .run
278+ self .run = self ._run_with_trace
239279 super ().start ()
240280
241281
@@ -338,10 +378,6 @@ def results(self) -> Data_Out:
338378 results : List [Data_Out ] = []
339379 for thread in self ._threads :
340380 results += thread .result
341- if thread .status == 'Idle' :
342- raise exceptions .ThreadNotRunningError ()
343- elif thread .status == 'Running' :
344- raise exceptions .ThreadStillRunningError ()
345381 return results
346382
347383
@@ -397,6 +433,19 @@ def join(self) -> bool:
397433 return True
398434
399435
436+ def kill (self ) -> None :
437+ """
438+ Kills the threads
439+
440+ Raises
441+ ------
442+ ThreadNotInitializedError: If the thread is not initialized
443+ ThreadNotRunningError: If the thread is not running
444+ """
445+ for thread in self ._threads :
446+ thread .kill ()
447+
448+
400449 def start (self ) -> None :
401450 """
402451 Starts the threads
@@ -424,3 +473,22 @@ def start(self) -> None:
424473 )
425474 self ._threads .append (chunk_thread )
426475 chunk_thread .start ()
476+
477+
478+
479+
480+
481+ # Handle abrupt exit
482+ def service_shutdown (signum , frame ):
483+ print ('\n Caught signal %d' % signum )
484+ print ('Gracefully killing active threads' )
485+
486+ for thread in threading .enumerate ():
487+ if isinstance (thread , Thread ):
488+ thread .kill ()
489+ sys .exit (0 )
490+
491+
492+ # Register the signal handlers
493+ signal .signal (signal .SIGTERM , service_shutdown )
494+ signal .signal (signal .SIGINT , service_shutdown )
0 commit comments