22import threading
33from queue import Queue
44from typing import Any , Iterable
5+ from concurrent .futures import ThreadPoolExecutor
56
67logger = logging .getLogger (__name__ )
78
@@ -32,10 +33,7 @@ class PrefetchGenerator:
3233 Useful for modifying the generator results based on data from a network
3334 """
3435
35- def __init__ (self ,
36- data : Iterable [Any ],
37- prefetch_limit = 20 ,
38- max_concurrency = 4 ):
36+ def __init__ (self , data : Iterable [Any ], prefetch_limit = 20 , num_executors = 4 ):
3937 if isinstance (data , (list , tuple )):
4038 self ._data = (r for r in data )
4139 else :
@@ -44,14 +42,13 @@ def __init__(self,
4442 self .queue = Queue (prefetch_limit )
4543 self ._data = ThreadSafeGen (self ._data )
4644 self .completed_threads = 0
47- self .max_concurrency = max_concurrency
48- self .threads = [
49- threading .Thread (target = self .fill_queue )
50- for _ in range (max_concurrency )
51- ]
52- for thread in self .threads :
53- thread .daemon = True
54- thread .start ()
45+ # Can only be iterated over once; otherwise queue.get hangs forever.
46+ self .done = False
47+ self .num_executors = num_executors
48+ with ThreadPoolExecutor (max_workers = num_executors ) as executor :
49+ self .futures = [
50+ executor .submit (self .fill_queue ) for _ in range (num_executors )
51+ ]
5552
5653 def _process (self , value ) -> Any :
5754 raise NotImplementedError ("Abstract method needs to be implemented" )
@@ -73,10 +70,13 @@ def __iter__(self):
7370 return self
7471
7572 def __next__ (self ) -> Any :
73+ if self .done :
74+ raise StopIteration
7675 value = self .queue .get ()
7776 while value is None :
7877 self .completed_threads += 1
79- if self .completed_threads == self .max_concurrency :
78+ if self .completed_threads == self .num_executors :
79+ self .done = True
8080 raise StopIteration
8181 value = self .queue .get ()
8282 return value
0 commit comments