Skip to content

Commit 5f1905c

Browse files
+ syntax changes
+ ParallelProcessing.join method
1 parent e7457db commit 5f1905c

File tree

1 file changed

+88
-30
lines changed

1 file changed

+88
-30
lines changed

src/thread/thread.py

Lines changed: 88 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Thread:
4141
kwargs : Mapping[str, Data_In]
4242

4343
errors : List[Exception]
44-
ignore_errors : Sequence[Exception]
44+
ignore_errors : Sequence[type[Exception]]
4545
suppress_errors: bool
4646

4747
overflow_args : Sequence[Overflow_In]
@@ -52,7 +52,7 @@ def __init__(
5252
target: Callable[Concatenate[Data_In, ...], Data_Out],
5353
args: Sequence[Data_In] = (),
5454
kwargs: Mapping[str, Data_In] = {},
55-
ignore_errors: Sequence[Exception] = (),
55+
ignore_errors: Sequence[type[Exception]] = (),
5656
suppress_errors: bool = False,
5757

5858
name: Optional[str] = None,
@@ -75,8 +75,10 @@ def __init__(
7575
:param *: These are arguments parsed to `threading.Thread`
7676
:param **: These are arguments parsed to `thread.Thread`
7777
"""
78+
self._thread = None
7879
self.status = 'Idle'
7980
self.hooks = []
81+
self.returned_value = None
8082

8183
self.target = self._wrap_target(target)
8284
self.args = args
@@ -105,38 +107,41 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
105107
try:
106108
self.returned_value = target(*args, **kwargs)
107109
except Exception as e:
108-
if e not in self.ignore_errors:
110+
if type(e) not in self.ignore_errors:
109111
self.status = 'Errored'
110112
self.errors.append(e)
111-
112-
if not self.suppress_errors:
113-
raise
114-
else:
115-
return
113+
return
116114

117115
self.status = 'Invoking hooks'
118116
self._invoke_hooks()
119117
self.status = 'Completed'
120118
return wrapper
121-
119+
122120

123121
def _invoke_hooks(self) -> None:
124122
trace = exceptions.HookRuntimeError()
125123
for hook in self.hooks:
126124
try:
127125
hook(self.returned_value)
128126
except Exception as e:
129-
if not self.suppress_errors and (e not in self.ignore_errors):
127+
if type(e) not in self.ignore_errors:
130128
trace.add_exception_case(
131129
hook.__name__,
132130
e
133131
)
134132

135133
if trace.count > 0:
136134
self.errors.append(trace)
137-
if not self.suppress_errors:
138-
raise trace
135+
136+
137+
def _handle_exceptions(self) -> None:
138+
"""Raises exceptions if not suppressed"""
139+
if self.suppress_errors:
140+
return
139141

142+
for e in self.errors:
143+
raise e
144+
140145

141146
@property
142147
def result(self) -> Data_Out:
@@ -145,12 +150,35 @@ def result(self) -> Data_Out:
145150
146151
Raises
147152
------
148-
ThreadeStillRunningError: If the thread is still running
153+
ThreadNotInitializedError: If the thread is not initialized
154+
ThreadNotRunningError: If the thread is not running
155+
ThreadStillRunningError: If the thread is still running
149156
"""
157+
if not self._thread:
158+
raise exceptions.ThreadNotInitializedError()
159+
160+
if self.status == 'Idle':
161+
raise exceptions.ThreadNotRunningError()
162+
163+
self._handle_exceptions()
150164
if self.status in ['Invoking hooks', 'Completed']:
151165
return self.returned_value
152166
else:
153167
raise exceptions.ThreadStillRunningError()
168+
169+
170+
def add_hook(self, hook: Callable[[Data_Out], Any | None]) -> None:
171+
"""
172+
Adds a hook to the thread
173+
-------------------------
174+
Hooks are executed automatically after a successful thread execution.
175+
The returned value is parsed directly into the hook
176+
177+
Parameters
178+
----------
179+
:param hook: This should be a function which takes the output value of `target` and should return None
180+
"""
181+
self.hooks.append(hook)
154182

155183

156184
def join(self, timeout: Optional[float] = None) -> 'JoinTerminatedStatus':
@@ -177,6 +205,7 @@ def join(self, timeout: Optional[float] = None) -> 'JoinTerminatedStatus':
177205
raise exceptions.ThreadNotRunningError()
178206

179207
self._thread.join(timeout)
208+
self._handle_exceptions()
180209
return JoinTerminatedStatus(self._thread.is_alive() and 'Timeout Exceeded' or 'Thread terminated')
181210

182211

@@ -202,7 +231,7 @@ def start(self) -> None:
202231
"""
203232
if self._thread is not None and self._thread.is_alive():
204233
raise exceptions.ThreadStillRunningError()
205-
234+
206235
self._thread = threading.Thread(
207236
target = self.target,
208237
args = self.args,
@@ -214,6 +243,7 @@ def start(self) -> None:
214243

215244

216245

246+
217247
class ParallelProcessing:
218248
"""
219249
Multi-Threaded Parallel Processing
@@ -282,11 +312,8 @@ def _wrap_function(
282312
def wrapper(data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any) -> List[Data_Out]:
283313
computed: List[Data_Out] = []
284314
for data_entry in data_chunk:
285-
try:
286-
v = function(data_entry, *args, **kwargs)
287-
computed.append(v)
288-
except Exception:
289-
pass
315+
v = function(data_entry, *args, **kwargs)
316+
computed.append(v)
290317

291318
self._completed += 1
292319
if self._completed == len(self._threads):
@@ -303,16 +330,19 @@ def results(self) -> Data_Out:
303330
304331
Raises
305332
------
306-
ThreadeStillRunningError: If the threads are still running
333+
ThreadNotInitializedError: If the threads are not initialized
334+
ThreadNotRunningError: If the threads are not running
335+
ThreadStillRunningError: If the threads are still running
307336
"""
308-
if self._completed == len(self._threads):
309-
results: List[Data_Out] = []
310-
for thread in self._threads:
311-
results += thread.result
312-
return results
313-
else:
314-
raise exceptions.ThreadStillRunningError()
337+
if len(self._threads) == 0:
338+
raise exceptions.ThreadNotInitializedError()
339+
340+
results: List[Data_Out] = []
341+
for thread in self._threads:
342+
results += thread.result
343+
return results
315344

345+
316346
def get_return_values(self) -> List[Data_Out]:
317347
"""
318348
Halts the current thread execution until the thread completes
@@ -328,6 +358,30 @@ def get_return_values(self) -> List[Data_Out]:
328358
return results
329359

330360

361+
def join(self) -> 'JoinTerminatedStatus':
362+
"""
363+
Halts the current thread execution until a thread completes or exceeds the timeout
364+
365+
Returns
366+
-------
367+
:returns JoinTerminatedStatus: Why the method stoped halting the thread
368+
369+
Raises
370+
------
371+
ThreadNotInitializedError: If the thread is not initialized
372+
ThreadNotRunningError: If the thread is not running
373+
"""
374+
if len(self._threads) == 0:
375+
raise exceptions.ThreadNotInitializedError()
376+
377+
if self.status == 'Idle':
378+
raise exceptions.ThreadNotRunningError()
379+
380+
for thread in self._threads:
381+
thread.join()
382+
return JoinTerminatedStatus('Thread terminated')
383+
384+
331385
def start(self) -> None:
332386
"""
333387
Starts the threads
@@ -342,11 +396,15 @@ def start(self) -> None:
342396
self.status = 'Running'
343397
max_threads = min(self.max_threads, len(self.dataset))
344398

345-
for data_chunk in numpy.array_split(self.dataset, max_threads):
399+
parsed_args = self.overflow_kwargs.get('args', [])
400+
name_format = self.overflow_kwargs.get('name') and self.overflow_kwargs['name'] + '%s'
401+
self.overflow_kwargs = { i: v for i,v in self.overflow_kwargs.items() if i != 'name' and i != 'args' }
402+
403+
for i, data_chunk in enumerate(numpy.array_split(self.dataset, max_threads)):
346404
chunk_thread = Thread(
347405
target = self.function,
348-
args = data_chunk.tolist(),
349-
*self.overflow_args,
406+
args = [data_chunk.tolist(), *parsed_args, *self.overflow_args],
407+
name = name_format and name_format % i or None,
350408
**self.overflow_kwargs
351409
)
352410
self._threads.append(chunk_thread)

0 commit comments

Comments
 (0)