Skip to content

Commit 655b0df

Browse files
Added Progress bar support
1 parent 94c3660 commit 655b0df

File tree

1 file changed

+27
-19
lines changed

1 file changed

+27
-19
lines changed

src/thread/thread.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import (
1111
Any, List,
1212
Callable, Union, Optional, Literal,
13-
Mapping, Sequence, Tuple
13+
Mapping, Sequence, Tuple, TypedDict
1414
)
1515

1616

@@ -81,7 +81,7 @@ def __init__(
8181
:param **: These are arguments parsed to `thread.Thread`
8282
"""
8383
_target = self._wrap_target(target)
84-
self.returned_values = None
84+
self.returned_value = None
8585
self.status = 'Idle'
8686
self.hooks = []
8787

@@ -312,6 +312,14 @@ def start(self) -> None:
312312

313313

314314

315+
class _ThreadWorker:
316+
progress: float
317+
thread: Thread
318+
319+
def __init__(self, thread: Thread, progress: float = 0) -> None:
320+
self.thread = thread
321+
self.progress = progress
322+
315323
class ParallelProcessing:
316324
"""
317325
Multi-Threaded Parallel Processing
@@ -320,15 +328,14 @@ class ParallelProcessing:
320328
Type-Safe and provides more functionality on top
321329
"""
322330

323-
_threads : List[Thread]
331+
_threads : List[_ThreadWorker]
324332
_completed : int
325-
_return_vales : Mapping[int, List[Data_Out]]
326333

327334
status : ThreadStatus
328335
function : Callable[..., List[Data_Out]]
329336
dataset : Sequence[Data_In]
330337
max_threads : int
331-
338+
332339
overflow_args : Sequence[Overflow_In]
333340
overflow_kwargs: Mapping[str, Overflow_In]
334341

@@ -378,11 +385,12 @@ def _wrap_function(
378385
function: Callable[..., Data_Out]
379386
) -> Callable[..., List[Data_Out]]:
380387
@wraps(function)
381-
def wrapper(data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any) -> List[Data_Out]:
388+
def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any) -> List[Data_Out]:
382389
computed: List[Data_Out] = []
383-
for data_entry in data_chunk:
390+
for i, data_entry in enumerate(data_chunk):
384391
v = function(data_entry, *args, **kwargs)
385392
computed.append(v)
393+
self._threads[index].progress = round(i/len(data_chunk), 2)
386394

387395
self._completed += 1
388396
if self._completed == len(self._threads):
@@ -407,8 +415,8 @@ def results(self) -> Data_Out:
407415
raise exceptions.ThreadNotInitializedError()
408416

409417
results: List[Data_Out] = []
410-
for thread in self._threads:
411-
results += thread.result
418+
for entry in self._threads:
419+
results += entry.thread.result
412420
return results
413421

414422

@@ -422,7 +430,7 @@ def is_alive(self) -> bool:
422430
"""
423431
if len(self._threads) == 0:
424432
raise exceptions.ThreadNotInitializedError()
425-
return any(thread.is_alive() for thread in self._threads)
433+
return any(entry.thread.is_alive() for entry in self._threads)
426434

427435

428436
def get_return_values(self) -> List[Data_Out]:
@@ -434,9 +442,9 @@ def get_return_values(self) -> List[Data_Out]:
434442
:returns Any: The return value of the target function
435443
"""
436444
results: List[Data_Out] = []
437-
for thread in self._threads:
438-
thread.join()
439-
results += thread.result
445+
for entry in self._threads:
446+
entry.thread.join()
447+
results += entry.thread.result
440448
return results
441449

442450

@@ -459,8 +467,8 @@ def join(self) -> bool:
459467
if self.status == 'Idle':
460468
raise exceptions.ThreadNotRunningError()
461469

462-
for thread in self._threads:
463-
thread.join()
470+
for entry in self._threads:
471+
entry.thread.join()
464472
return True
465473

466474

@@ -473,8 +481,8 @@ def kill(self) -> None:
473481
ThreadNotInitializedError: If the thread is not initialized
474482
ThreadNotRunningError: If the thread is not running
475483
"""
476-
for thread in self._threads:
477-
thread.kill()
484+
for entry in self._threads:
485+
entry.thread.kill()
478486

479487

480488
def start(self) -> None:
@@ -498,11 +506,11 @@ def start(self) -> None:
498506
for i, data_chunk in enumerate(numpy.array_split(self.dataset, max_threads)):
499507
chunk_thread = Thread(
500508
target = self.function,
501-
args = [data_chunk.tolist(), *parsed_args, *self.overflow_args],
509+
args = [i, data_chunk.tolist(), *parsed_args, *self.overflow_args],
502510
name = name_format and name_format % i or None,
503511
**self.overflow_kwargs
504512
)
505-
self._threads.append(chunk_thread)
513+
self._threads.append(_ThreadWorker(chunk_thread, 0))
506514
chunk_thread.start()
507515

508516

0 commit comments

Comments
 (0)