# See the License for the specific language governing permissions and
# limitations under the License.
1414
15+ import itertools
1516import threading as td
1617import time
1718import traceback
2122
2223from starlette .responses import Response
2324
24- from . .exceptions import UserRuntimeException
25- from . .log import logger
25+ from cortex_internal . lib .exceptions import UserRuntimeException
26+ from cortex_internal . lib .log import logger
2627
2728
2829class DynamicBatcher :
29- def __init__ (self , predictor_impl : Callable , max_batch_size : int , batch_interval : int ):
30+ def __init__ (
31+ self ,
32+ predictor_impl : Callable ,
33+ max_batch_size : int ,
34+ batch_interval : int ,
35+ test_mode : bool = False ,
36+ ):
3037 self .predictor_impl = predictor_impl
3138
3239 self .batch_max_size = max_batch_size
3340 self .batch_interval = batch_interval # measured in seconds
41+ self .test_mode = test_mode # only for unit testing
42+ self ._test_batch_lengths = [] # only when unit testing
3443
35- # waiter prevents new threads from modifying the input batch while a batch prediction is in progress
36- self .waiter = td .Event ()
37- self .waiter .set ()
38-
39- self .barrier = td .Barrier (self .batch_max_size + 1 , action = self .waiter .clear )
44+ self .barrier = td .Barrier (self .batch_max_size + 1 )
4045
4146 self .samples = {}
4247 self .predictions = {}
43- td .Thread (target = self ._batch_engine ).start ()
48+ td .Thread (target = self ._batch_engine , daemon = True ).start ()
49+
50+ self .sample_id_generator = itertools .count ()
4451
4552 def _batch_engine (self ):
4653 while True :
@@ -54,43 +61,48 @@ def _batch_engine(self):
5461 pass
5562
5663 self .predictions = {}
57-
64+ sample_ids = self . _get_sample_ids ( self . batch_max_size )
5865 try :
5966 if self .samples :
60- batch = self ._make_batch (self . samples )
67+ batch = self ._make_batch (sample_ids )
6168
6269 predictions = self .predictor_impl .predict (** batch )
6370 if not isinstance (predictions , list ):
6471 raise UserRuntimeException (
6572 f"please return a list when using server side batching, got { type (predictions )} "
6673 )
6774
68- self .predictions = dict (zip (self .samples .keys (), predictions ))
75+ if self .test_mode :
76+ self ._test_batch_lengths .append (len (predictions ))
77+
78+ self .predictions = dict (zip (sample_ids , predictions ))
6979 except Exception as e :
70- self .predictions = {thread_id : e for thread_id in self . samples }
80+ self .predictions = {sample_id : e for sample_id in sample_ids }
7181 logger .error (traceback .format_exc ())
7282 finally :
73- self .samples = {}
83+ for sample_id in sample_ids :
84+ del self .samples [sample_id ]
7485 self .barrier .reset ()
75- self .waiter .set ()
7686
77- @staticmethod
78- def _make_batch (samples : Dict [int , Dict [str , Any ]]) -> Dict [str , List [Any ]]:
87+ def _get_sample_ids (self , max_number : int ) -> List [int ]:
88+ if len (self .samples ) <= max_number :
89+ return list (self .samples .keys ())
90+ return sorted (self .samples )[:max_number ]
91+
92+ def _make_batch (self , sample_ids : List [int ]) -> Dict [str , List [Any ]]:
7993 batched_samples = defaultdict (list )
80- for thread_id in samples :
81- for key , sample in samples [thread_id ].items ():
94+ for sample_id in sample_ids :
95+ for key , sample in self . samples [sample_id ].items ():
8296 batched_samples [key ].append (sample )
8397
8498 return dict (batched_samples )
8599
86- def _enqueue_request (self , ** kwargs ):
100+ def _enqueue_request (self , sample_id : int , ** kwargs ):
87101 """
88102 Enqueue sample for batch inference. This is a blocking method.
89103 """
90- thread_id = td .get_ident ()
91104
92- self .waiter .wait ()
93- self .samples [thread_id ] = kwargs
105+ self .samples [sample_id ] = kwargs
94106 try :
95107 self .barrier .wait ()
96108 except td .BrokenBarrierError :
@@ -101,20 +113,20 @@ def predict(self, **kwargs):
101113 Queues a request to be batched with other incoming request, waits for the response
102114 and returns the prediction result. This is a blocking method.
103115 """
104- self ._enqueue_request (** kwargs )
105- prediction = self ._get_prediction ()
116+ sample_id = next (self .sample_id_generator )
117+ self ._enqueue_request (sample_id , ** kwargs )
118+ prediction = self ._get_prediction (sample_id )
106119 return prediction
107120
108- def _get_prediction (self ) -> Any :
121+ def _get_prediction (self , sample_id : int ) -> Any :
109122 """
110123 Return the prediction. This is a blocking method.
111124 """
112- thread_id = td .get_ident ()
113- while thread_id not in self .predictions :
125+ while sample_id not in self .predictions :
114126 time .sleep (0.001 )
115127
116- prediction = self .predictions [thread_id ]
117- del self .predictions [thread_id ]
128+ prediction = self .predictions [sample_id ]
129+ del self .predictions [sample_id ]
118130
119131 if isinstance (prediction , Exception ):
120132 return Response (
0 commit comments