1- import logging
21import threading
32from queue import Queue
43from typing import Any , Iterable
5- import threading
64
7- logger = logging . getLogger ( __name__ )
5+ from labelbox . exceptions import ThreadException
86
97
108class ThreadSafeGen :
@@ -27,13 +25,43 @@ def __next__(self):
2725 return next (self .iterable )
2826
2927
28+ class PrefetchThread (threading .Thread ):
29+ """Class to override the Thread class. Helps raise
30+ exceptions to the main caller thread
31+ """
32+
33+ def __init__ (self , ** kwargs ):
34+ super ().__init__ (** kwargs )
35+ self .exc = None
36+
37+ def run (self ):
38+ try :
39+ super ().run ()
40+ except BaseException as e :
41+ self .exc = e
42+
43+ def join (self , timeout = None ):
44+ threading .Thread .join (self )
45+ if self .exc :
46+ raise self .exc
47+
48+
3049class PrefetchGenerator :
3150 """
3251 Applys functions asynchronously to the output of a generator.
3352 Useful for modifying the generator results based on data from a network
3453 """
3554
36- def __init__ (self , data : Iterable [Any ], prefetch_limit = 20 , num_executors = 4 ):
55+ #maybe change num exec to just 1, and if 1, make sync
56+ #instead of self.get qeue in next, itll return just self._data.next
57+ #kwarg on export for multithread, and all other things that use prefetch
58+
59+ # def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
60+ def __init__ (self ,
61+ data : Iterable [Any ],
62+ prefetch_limit = 20 ,
63+ num_executors = 1 ,
64+ multithread : bool = False ):
3765 if isinstance (data , (list , tuple )):
3866 self ._data = (r for r in data )
3967 else :
@@ -44,14 +72,17 @@ def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
4472 self .completed_threads = 0
4573 # Can only iterate over once it the queue.get hangs forever.
4674 self .done = False
75+ if multithread :
76+ num_executors = 4
4777 self .num_executors = num_executors
4878 self .threads = [
49- threading .Thread (target = self .fill_queue )
50- for _ in range (num_executors )
79+ PrefetchThread (target = self .fill_queue ) for _ in range (num_executors )
5180 ]
5281 for thread in self .threads :
5382 thread .daemon = True
5483 thread .start ()
84+ for thread in self .threads :
85+ thread .join ()
5586
5687 def _process (self , value ) -> Any :
5788 raise NotImplementedError ("Abstract method needs to be implemented" )
@@ -63,11 +94,9 @@ def fill_queue(self):
6394 if value is None :
6495 raise ValueError ("Unexpected None" )
6596 self .queue .put (value )
66- except Exception as e :
67- logger .warning ("Unexpected exception while filling the queue. %r" ,
68- e )
69- finally :
70- self .queue .put (None )
97+ except :
98+ raise ThreadException (
99+ "Unexpected exception while filling the queue." )
71100
72101 def __iter__ (self ):
73102 return self
0 commit comments