Skip to content

Commit f3ec68f

Browse files
+ Thread killing\n+ Graceful killing
1 parent 1485edc commit f3ec68f

File tree

1 file changed

+77
-9
lines changed

1 file changed

+77
-9
lines changed

src/thread/thread.py

Lines changed: 77 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
import numpy
1+
import sys
2+
import signal
23
import threading
4+
5+
import numpy
36
from . import exceptions
47

58
from functools import wraps
@@ -10,7 +13,7 @@
1013
)
1114

1215

13-
ThreadStatus = Literal['Idle', 'Running', 'Invoking hooks', 'Completed', 'Errored']
16+
ThreadStatus = Literal['Idle', 'Running', 'Invoking hooks', 'Completed', 'Errored', 'Killed']
1417
Data_In = Any
1518
Data_Out = Any
1619
Overflow_In = Any
@@ -34,6 +37,7 @@ class Thread(threading.Thread):
3437

3538
# threading.Thread stuff
3639
_initialized : bool
40+
_run : Callable
3741

3842

3943
def __init__(
@@ -108,6 +112,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
108112

109113

110114
def _invoke_hooks(self) -> None:
115+
"""Invokes hooks in the thread"""
111116
errors: List[Tuple[Exception, str]] = []
112117
for hook in self.hooks:
113118
try:
@@ -126,14 +131,33 @@ def _invoke_hooks(self) -> None:
126131

127132

128133
def _handle_exceptions(self) -> None:
129-
"""Raises exceptions if not suppressed"""
134+
"""Raises exceptions if not suppressed in the main thread"""
130135
if self.suppress_errors:
131136
return
132137

133138
for e in self.errors:
134139
raise e
135140

136141

142+
def global_trace(self, frame, event: str, arg) -> Callable | None:
143+
if event == 'call':
144+
return self.local_trace
145+
146+
def local_trace(self, frame, event, arg):
147+
if self.status == 'Killed' and event == 'line':
148+
print('KILLED ident:%s' % self.ident)
149+
raise SystemExit()
150+
return self.local_trace
151+
152+
def _run_with_trace(self) -> None:
153+
"""This will replace `threading.Thread`'s `run()` method"""
154+
if not self._run:
155+
raise exceptions.ThreadNotInitializedError('Running `_run_with_trace` may cause unintended behaviour, run `start` instead')
156+
157+
sys.settrace(self.global_trace)
158+
self._run()
159+
160+
137161
@property
138162
def result(self) -> Data_Out:
139163
"""
@@ -147,7 +171,7 @@ def result(self) -> Data_Out:
147171
"""
148172
if not self._initialized:
149173
raise exceptions.ThreadNotInitializedError()
150-
if self.status == 'Idle':
174+
if self.status in ['Idle', 'Killed']:
151175
raise exceptions.ThreadNotRunningError()
152176

153177
self._handle_exceptions()
@@ -204,7 +228,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
204228
if not self._initialized:
205229
raise exceptions.ThreadNotInitializedError()
206230

207-
if self.status == 'Idle':
231+
if self.status == ['Idle', 'Killed']:
208232
raise exceptions.ThreadNotRunningError()
209233

210234
super().join(timeout)
@@ -224,6 +248,20 @@ def get_return_value(self) -> Data_Out:
224248
return self.result
225249

226250

251+
def kill(self) -> None:
252+
"""
253+
Kills the thread
254+
255+
Raises
256+
------
257+
ThreadNotInitializedError: If the thread is not initialized
258+
ThreadNotRunningError: If the thread is not running
259+
"""
260+
if not self.is_alive():
261+
raise exceptions.ThreadNotRunningError()
262+
self.status = 'Killed'
263+
264+
227265
def start(self) -> None:
228266
"""
229267
Starts the thread
@@ -236,6 +274,8 @@ def start(self) -> None:
236274
if self.is_alive():
237275
raise exceptions.ThreadStillRunningError()
238276

277+
self._run = self.run
278+
self.run = self._run_with_trace
239279
super().start()
240280

241281

@@ -338,10 +378,6 @@ def results(self) -> Data_Out:
338378
results: List[Data_Out] = []
339379
for thread in self._threads:
340380
results += thread.result
341-
if thread.status == 'Idle':
342-
raise exceptions.ThreadNotRunningError()
343-
elif thread.status == 'Running':
344-
raise exceptions.ThreadStillRunningError()
345381
return results
346382

347383

@@ -397,6 +433,19 @@ def join(self) -> bool:
397433
return True
398434

399435

436+
def kill(self) -> None:
437+
"""
438+
Kills the threads
439+
440+
Raises
441+
------
442+
ThreadNotInitializedError: If the thread is not initialized
443+
ThreadNotRunningError: If the thread is not running
444+
"""
445+
for thread in self._threads:
446+
thread.kill()
447+
448+
400449
def start(self) -> None:
401450
"""
402451
Starts the threads
@@ -424,3 +473,22 @@ def start(self) -> None:
424473
)
425474
self._threads.append(chunk_thread)
426475
chunk_thread.start()
476+
477+
478+
479+
480+
481+
# Handle abrupt exit
482+
def service_shutdown(signum, frame):
483+
print('\nCaught signal %d' % signum)
484+
print('Gracefully killing active threads')
485+
486+
for thread in threading.enumerate():
487+
if isinstance(thread, Thread):
488+
thread.kill()
489+
sys.exit(0)
490+
491+
492+
# Register the signal handlers
493+
signal.signal(signal.SIGTERM, service_shutdown)
494+
signal.signal(signal.SIGINT, service_shutdown)

0 commit comments

Comments
 (0)