22from queue import Queue
33from typing import Any , Iterable
44
5- from labelbox .exceptions import ThreadException
6-
75
86class ThreadSafeGen :
97 """
@@ -25,27 +23,6 @@ def __next__(self):
2523 return next (self .iterable )
2624
2725
28- class PrefetchThread (threading .Thread ):
29- """Class to override the Thread class. Helps raise
30- exceptions to the main caller thread
31- """
32-
33- def __init__ (self , ** kwargs ):
34- super ().__init__ (** kwargs )
35- self .exc = None
36-
37- def run (self ):
38- try :
39- super ().run ()
40- except BaseException as e :
41- self .exc = e
42-
43- def join (self , timeout = None ):
44- threading .Thread .join (self )
45- if self .exc :
46- raise self .exc
47-
48-
4926class PrefetchGenerator :
5027 """
5128 Applys functions asynchronously to the output of a generator.
@@ -59,7 +36,7 @@ class PrefetchGenerator:
5936 def __init__ (self ,
6037 data : Iterable [Any ],
6138 prefetch_limit = 20 ,
62- num_executors = 1 ,
39+ num_executors = 4 ,
6340 multithread : bool = False ):
6441 if isinstance (data , (list , tuple )):
6542 self ._data = (r for r in data )
@@ -71,16 +48,18 @@ def __init__(self,
7148 self .completed_threads = 0
7249 # Can only iterate over once it the queue.get hangs forever.
7350 self .done = False
51+
7452 if multithread :
75- num_executors = 4
76- self .num_executors = num_executors
77- self .threads = [
78- PrefetchThread (target = self .fill_queue ) for _ in range (num_executors )
79- ]
80- for thread in self .threads :
81- thread .daemon = True
82- thread .start ()
83- thread .join ()
53+ self .num_executors = num_executors
54+ self .threads = [
55+ threading .Thread (target = self .fill_queue )
56+ for _ in range (num_executors )
57+ ]
58+ for thread in self .threads :
59+ thread .daemon = True
60+ thread .start ()
61+ else :
62+ self .fill_queue ()
8463
8564 def _process (self , value ) -> Any :
8665 raise NotImplementedError ("Abstract method needs to be implemented" )
@@ -93,16 +72,18 @@ def fill_queue(self):
9372 raise ValueError ("Unexpected None" )
9473 self .queue .put (value )
9574 except :
96- raise ThreadException (
97- "Unexpected exception while filling the queue." )
75+ self . queue . put (
76+ ValueError ( "Unexpected exception while filling queue." ) )
9877
9978 def __iter__ (self ):
10079 return self
10180
10281 def __next__ (self ) -> Any :
10382 if self .done :
10483 raise StopIteration
105- value = self .queue .get ()
84+ value = self .queue .get (block = False )
85+ if isinstance (value , ValueError ):
86+ raise value
10687 while value is None :
10788 self .completed_threads += 1
10889 if self .completed_threads == self .num_executors :
0 commit comments