Commit 0b1b649

Prevent threads from being stuck in DynamicBatcher (#1915)
Co-authored-by: Robert Lucian Chiriac <robert.lucian.chiriac@gmail.com>
1 parent 132c2e9 commit 0b1b649

5 files changed: +165 −30 lines changed

images/test/Dockerfile

Lines changed: 4 additions & 0 deletions

```diff
@@ -10,6 +10,10 @@ RUN pip install --upgrade pip && \
 COPY pkg /src
 COPY images/test/run.sh /src/run.sh
 
+COPY pkg/cortex/serve/log_config.yaml /src/cortex/serve/log_config.yaml
+ENV CORTEX_LOG_LEVEL DEBUG
+ENV CORTEX_LOG_CONFIG_FILE /src/cortex/serve/log_config.yaml
+
 RUN pip install --no-deps /src/cortex/serve/ && \
     rm -rf /root/.cache/pip*
 
```

images/test/run.sh

Lines changed: 6 additions & 0 deletions

```diff
@@ -18,6 +18,12 @@
 err=0
 trap 'err=1' ERR
 
+function substitute_env_vars() {
+    file_to_run_substitution=$1
+    python -c "from cortex_internal.lib import util; import os; util.expand_environment_vars_on_file('$file_to_run_substitution')"
+}
+
+substitute_env_vars $CORTEX_LOG_CONFIG_FILE
 pytest lib/test
 
 test $err = 0
```
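
Note: `substitute_env_vars` shells out to `cortex_internal.lib.util.expand_environment_vars_on_file` so the `ENV` values set in the Dockerfile above (such as `CORTEX_LOG_LEVEL`) get substituted into `log_config.yaml` before pytest runs. A minimal sketch of what such a helper could look like, assuming it expands `$VAR` references in place (hypothetical; the real implementation in `cortex_internal.lib.util` may differ):

```python
# Hypothetical sketch, not the actual cortex_internal implementation:
# expand $VAR / ${VAR} references in a file, writing the result in place.
import os


def expand_environment_vars_on_file(file_path: str):
    with open(file_path, "r") as f:
        contents = f.read()
    with open(file_path, "w") as f:
        # os.path.expandvars leaves unknown $VARs untouched
        f.write(os.path.expandvars(contents))
```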

pkg/cortex/serve/cortex_internal.requirements.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -1,3 +1,4 @@
+grpcio==1.32.0
 boto3==1.14.53
 google-cloud-storage==1.32.0
 datadog==0.39.0
```

pkg/cortex/serve/cortex_internal/lib/api/batching.py

Lines changed: 42 additions & 30 deletions

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import threading as td
 import time
 import traceback
@@ -21,26 +22,32 @@
 
 from starlette.responses import Response
 
-from ..exceptions import UserRuntimeException
-from ..log import logger
+from cortex_internal.lib.exceptions import UserRuntimeException
+from cortex_internal.lib.log import logger
 
 
 class DynamicBatcher:
-    def __init__(self, predictor_impl: Callable, max_batch_size: int, batch_interval: int):
+    def __init__(
+        self,
+        predictor_impl: Callable,
+        max_batch_size: int,
+        batch_interval: int,
+        test_mode: bool = False,
+    ):
         self.predictor_impl = predictor_impl
 
         self.batch_max_size = max_batch_size
         self.batch_interval = batch_interval  # measured in seconds
+        self.test_mode = test_mode  # only for unit testing
+        self._test_batch_lengths = []  # only when unit testing
 
-        # waiter prevents new threads from modifying the input batch while a batch prediction is in progress
-        self.waiter = td.Event()
-        self.waiter.set()
-
-        self.barrier = td.Barrier(self.batch_max_size + 1, action=self.waiter.clear)
+        self.barrier = td.Barrier(self.batch_max_size + 1)
 
         self.samples = {}
         self.predictions = {}
-        td.Thread(target=self._batch_engine).start()
+        td.Thread(target=self._batch_engine, daemon=True).start()
+
+        self.sample_id_generator = itertools.count()
 
     def _batch_engine(self):
         while True:
@@ -54,43 +61,48 @@ def _batch_engine(self):
                 pass
 
             self.predictions = {}
-
+            sample_ids = self._get_sample_ids(self.batch_max_size)
             try:
                 if self.samples:
-                    batch = self._make_batch(self.samples)
+                    batch = self._make_batch(sample_ids)
 
                     predictions = self.predictor_impl.predict(**batch)
                     if not isinstance(predictions, list):
                         raise UserRuntimeException(
                             f"please return a list when using server side batching, got {type(predictions)}"
                         )
 
-                    self.predictions = dict(zip(self.samples.keys(), predictions))
+                    if self.test_mode:
+                        self._test_batch_lengths.append(len(predictions))
+
+                    self.predictions = dict(zip(sample_ids, predictions))
             except Exception as e:
-                self.predictions = {thread_id: e for thread_id in self.samples}
+                self.predictions = {sample_id: e for sample_id in sample_ids}
                 logger.error(traceback.format_exc())
             finally:
-                self.samples = {}
+                for sample_id in sample_ids:
+                    del self.samples[sample_id]
                 self.barrier.reset()
-                self.waiter.set()
 
-    @staticmethod
-    def _make_batch(samples: Dict[int, Dict[str, Any]]) -> Dict[str, List[Any]]:
+    def _get_sample_ids(self, max_number: int) -> List[int]:
+        if len(self.samples) <= max_number:
+            return list(self.samples.keys())
+        return sorted(self.samples)[:max_number]
+
+    def _make_batch(self, sample_ids: List[int]) -> Dict[str, List[Any]]:
         batched_samples = defaultdict(list)
-        for thread_id in samples:
-            for key, sample in samples[thread_id].items():
+        for sample_id in sample_ids:
+            for key, sample in self.samples[sample_id].items():
                 batched_samples[key].append(sample)
 
         return dict(batched_samples)
 
-    def _enqueue_request(self, **kwargs):
+    def _enqueue_request(self, sample_id: int, **kwargs):
         """
         Enqueue sample for batch inference. This is a blocking method.
         """
-        thread_id = td.get_ident()
 
-        self.waiter.wait()
-        self.samples[thread_id] = kwargs
+        self.samples[sample_id] = kwargs
         try:
             self.barrier.wait()
         except td.BrokenBarrierError:
@@ -101,20 +113,20 @@ def predict(self, **kwargs):
         Queues a request to be batched with other incoming request, waits for the response
         and returns the prediction result. This is a blocking method.
         """
-        self._enqueue_request(**kwargs)
-        prediction = self._get_prediction()
+        sample_id = next(self.sample_id_generator)
+        self._enqueue_request(sample_id, **kwargs)
+        prediction = self._get_prediction(sample_id)
         return prediction
 
-    def _get_prediction(self) -> Any:
+    def _get_prediction(self, sample_id: int) -> Any:
         """
         Return the prediction. This is a blocking method.
         """
-        thread_id = td.get_ident()
-        while thread_id not in self.predictions:
+        while sample_id not in self.predictions:
             time.sleep(0.001)
 
-        prediction = self.predictions[thread_id]
-        del self.predictions[thread_id]
+        prediction = self.predictions[sample_id]
+        del self.predictions[sample_id]
 
         if isinstance(prediction, Exception):
             return Response(
```
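
In summary, the diff drops the `waiter` event, marks the batch-engine thread as a daemon, and keys in-flight samples by ids drawn from `itertools.count()` instead of `td.get_ident()`; per the commit title, this is what prevents threads from getting stuck. A minimal usage sketch of the batcher after this change, modeled on the new unit tests below (the echo predictor is a stand-in, not a Cortex class):

```python
# Usage sketch based on the unit tests in this commit.
import threading as td

from cortex_internal.lib.api.batching import DynamicBatcher


class EchoPredictor:
    def predict(self, payload):
        # payload arrives as a list of samples; return one prediction per sample
        return payload


batcher = DynamicBatcher(EchoPredictor(), max_batch_size=8, batch_interval=0.1)
results = []


def worker(i):
    # blocks until a full batch forms or batch_interval elapses
    results.append(batcher.predict(payload=i))


threads = [td.Thread(target=worker, args=(i,)) for i in range(16)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert sorted(results) == list(range(16))
```

With 16 submitters and `max_batch_size=8`, the first 8 trip the barrier immediately and the rest are picked up on a later cycle, so every `predict` call should eventually return.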
New file: unit tests for DynamicBatcher

Lines changed: 112 additions & 0 deletions

```diff
@@ -0,0 +1,112 @@
+# Copyright 2021 Cortex Labs, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import threading as td
+import itertools
+import time
+
+import cortex_internal.lib.api.batching as batching
+
+
+class Predictor:
+    def predict(self, payload):
+        time.sleep(0.2)
+        return payload
+
+
+def test_dynamic_batching_while_hitting_max_batch_size():
+    max_batch_size = 32
+    dynamic_batcher = batching.DynamicBatcher(
+        Predictor(), max_batch_size=max_batch_size, batch_interval=0.1, test_mode=True
+    )
+    counter = itertools.count(1)
+    event = td.Event()
+    global_list = []
+
+    def submitter():
+        while not event.is_set():
+            global_list.append(dynamic_batcher.predict(payload=next(counter)))
+            time.sleep(0.1)
+
+    running_threads = []
+    for _ in range(128):
+        thread = td.Thread(target=submitter, daemon=True)
+        thread.start()
+        running_threads.append(thread)
+
+    time.sleep(60)
+    event.set()
+
+    # if this fails, then the submitter threads are getting stuck
+    for thread in running_threads:
+        thread.join(3.0)
+        if thread.is_alive():
+            raise TimeoutError("thread", thread.getName(), "got stuck")
+
+    sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
+    sum2 = sum(global_list)
+    assert sum1 == sum2
+
+    # get the last 80% of batch lengths
+    # we ignore the first 20% because it may take some time for all threads to start making requests
+    batch_lengths = dynamic_batcher._test_batch_lengths
+    batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]
+
+    # verify that the batch size is always equal to the max batch size
+    assert len(set(batch_lengths)) == 1
+    assert max_batch_size in batch_lengths
+
+
+def test_dynamic_batching_while_hitting_max_interval():
+    max_batch_size = 32
+    dynamic_batcher = batching.DynamicBatcher(
+        Predictor(), max_batch_size=max_batch_size, batch_interval=1.0, test_mode=True
+    )
+    counter = itertools.count(1)
+    event = td.Event()
+    global_list = []
+
+    def submitter():
+        while not event.is_set():
+            global_list.append(dynamic_batcher.predict(payload=next(counter)))
+            time.sleep(0.1)
+
+    running_threads = []
+    for _ in range(2):
+        thread = td.Thread(target=submitter, daemon=True)
+        thread.start()
+        running_threads.append(thread)
+
+    time.sleep(30)
+    event.set()
+
+    # if this fails, then the submitter threads are getting stuck
+    for thread in running_threads:
+        thread.join(3.0)
+        if thread.is_alive():
+            raise TimeoutError("thread", thread.getName(), "got stuck")
+
+    sum1 = int(len(global_list) * (len(global_list) + 1) / 2)
+    sum2 = sum(global_list)
+    assert sum1 == sum2
+
+    # get the last 80% of batch lengths
+    # we ignore the first 20% because it may take some time for all threads to start making requests
+    batch_lengths = dynamic_batcher._test_batch_lengths
+    batch_lengths = batch_lengths[int(len(batch_lengths) * 0.2) :]
+
+    # verify that the batch size is always equal to the number of running threads
+    assert len(set(batch_lengths)) == 1
+    assert len(running_threads) in batch_lengths
```
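
Both tests ultimately exercise the standard-library `threading.Barrier` coordination the batcher relies on: submitters block in `barrier.wait()` until `max_batch_size` of them arrive, or the batch interval elapses and the engine serves whatever was enqueued, calling `barrier.reset()` in its `finally` block (visible in the diff above) to release the remaining waiters with `BrokenBarrierError`. A self-contained sketch of that behavior, independent of Cortex:

```python
# Standalone demonstration of the threading.Barrier semantics the batch
# engine relies on (3 parties here stand in for max_batch_size + 1).
import threading as td
import time

barrier = td.Barrier(3)


def submitter(name):
    try:
        barrier.wait()  # released normally only if all 3 parties arrive
        print(name, "released by a full barrier")
    except td.BrokenBarrierError:
        # barrier.reset() broke the wait; the engine would now serve
        # whatever samples were enqueued so far
        print(name, "released by reset() on a partial batch")


threads = [td.Thread(target=submitter, args=(f"w{i}",)) for i in range(2)]
for t in threads:
    t.start()

time.sleep(0.5)  # only 2 of 3 parties arrived: the "batch interval" elapses
barrier.reset()  # breaks the barrier, unblocking both waiters
for t in threads:
    t.join()
```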
