Commit 202749a

Author: Andrei Lapanik
Failed batches processing with dead-letter queue (#1713)
1 parent: 5dfebf0 · commit: 202749a

File tree: 13 files changed, +304 −150 lines

cli/cmd/lib_batch_apis.go

Lines changed: 2 additions & 2 deletions
@@ -121,7 +121,7 @@ func batchAPITable(batchAPI schema.APIResponse) string {
 		{Title: "job id"},
 		{Title: "status"},
 		{Title: "progress"}, // (succeeded/total)
-		{Title: "failed", Hidden: totalFailed == 0},
+		{Title: "failed attempts", Hidden: totalFailed == 0},
 		{Title: "start time"},
 		{Title: "duration"},
 	},
@@ -200,7 +200,7 @@ func getJob(env cliconfig.Environment, apiName string, jobID string) (string, er
 	Headers: []table.Header{
 		{Title: "total"},
 		{Title: "succeeded"},
-		{Title: "failed"},
+		{Title: "failed attempts"},
 		{Title: "avg time per batch"},
 	},
 	Rows: [][]interface{}{

docs/workloads/batch/endpoints.md

Lines changed: 26 additions & 2 deletions
@@ -33,6 +33,10 @@ POST <batch_api_endpoint>:
 {
     "workers": <int>,               # the number of workers to allocate for this job (required)
     "timeout": <int>,               # duration in seconds since the submission of a job before it is terminated (optional)
+    "sqs_dead_letter_queue": {      # specify a queue to redirect failed batches (optional)
+        "arn": <string>,            # ARN of the dead letter queue, e.g. arn:aws:sqs:us-west-2:123456789:failed.fifo
+        "max_receive_count": <int>  # number of times a batch is allowed to be handled by a worker before it is considered failed and transferred to the dead letter queue (must be >= 1)
+    },
     "item_list": {
         "items": [                  # a list of items that can be of any type (required)
             <any>,
@@ -54,6 +58,10 @@ RESPONSE:
     "api_id": <string>,
     "sqs_url": <string>,
     "timeout": <int>,
+    "sqs_dead_letter_queue": {
+        "arn": <string>,
+        "max_receive_count": <int>
+    },
     "created_time": <string>        # e.g. 2020-07-16T14:56:10.276007415Z
 }
 ```
@@ -76,6 +84,10 @@ POST <batch_api_endpoint>:
 {
     "workers": <int>,               # the number of workers to allocate for this job (required)
     "timeout": <int>,               # duration in seconds since the submission of a job before it is terminated (optional)
+    "sqs_dead_letter_queue": {      # specify a queue to redirect failed batches (optional)
+        "arn": <string>,            # ARN of the dead letter queue, e.g. arn:aws:sqs:us-west-2:123456789:failed.fifo
+        "max_receive_count": <int>  # number of times a batch is allowed to be handled by a worker before it is considered failed and transferred to the dead letter queue (must be >= 1)
+    },
     "file_path_lister": {
         "s3_paths": [<string>],     # can be S3 prefixes or complete S3 paths (required)
         "includes": [<string>],     # glob patterns (optional)
@@ -96,6 +108,10 @@ RESPONSE:
     "api_id": <string>,
     "sqs_url": <string>,
     "timeout": <int>,
+    "sqs_dead_letter_queue": {
+        "arn": <string>,
+        "max_receive_count": <int>
+    },
     "created_time": <string>        # e.g. 2020-07-16T14:56:10.276007415Z
 }
 ```
@@ -117,6 +133,10 @@ POST <batch_api_endpoint>:
 {
     "workers": <int>,               # the number of workers to allocate for this job (required)
     "timeout": <int>,               # duration in seconds since the submission of a job before it is terminated (optional)
+    "sqs_dead_letter_queue": {      # specify a queue to redirect failed batches (optional)
+        "arn": <string>,            # ARN of the dead letter queue, e.g. arn:aws:sqs:us-west-2:123456789:failed.fifo
+        "max_receive_count": <int>  # number of times a batch is allowed to be handled by a worker before it is considered failed and transferred to the dead letter queue (must be >= 1)
+    },
     "delimited_files": {
         "s3_paths": [<string>],     # can be S3 prefixes or complete S3 paths (required)
         "includes": [<string>],     # glob patterns (optional)
@@ -137,6 +157,10 @@ RESPONSE:
     "api_id": <string>,
     "sqs_url": <string>,
     "timeout": <int>,
+    "sqs_dead_letter_queue": {
+        "arn": <string>,
+        "max_receive_count": <int>
+    },
     "created_time": <string>        # e.g. 2020-07-16T14:56:10.276007415Z
 }
 ```
@@ -163,8 +187,8 @@ RESPONSE:
     "batches_in_queue": <int>                 # number of batches remaining in the queue
     "batch_metrics": {
         "succeeded": <int>                    # number of succeeded batches
-        "failed": <int>                       # number of failed batches
-        "avg_time_per_batch": <float>         # (optional) only available if batches have been completed
+        "failed": <int>                       # number of failed attempts
+        "avg_time_per_batch": <float>         # (optional) average time spent working on a batch (only considers successful attempts)
     },
     "worker_counts": {                        # worker counts are only available while a job is running
         "pending": <int>,                     # number of workers that are waiting for compute resources to be provisioned

docs/workloads/batch/example.md

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@ Deploy batch APIs that can orchestrate distributed batch inference jobs on large
 ## Key features
 
 * Distributed inference
-* Fault tolerance with queues
+* Automatic batch retries
+* Collect failed batches for debugging
 * Metrics and log aggregation
 * `on_job_complete` webhook
 * Scale to 0
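
The `on_job_complete` webhook listed above is the hook that the reworked batch.py below dispatches once the queue drains. For reference, a minimal batch predictor might look like the following; the constructor signature and method bodies are illustrative assumptions, not taken from this commit:

class PythonPredictor:
    def __init__(self, config, job_spec):
        self.processed = 0

    def predict(self, payload, batch_id):
        # handle one batch; an uncaught exception here counts as a failed
        # attempt, and with sqs_dead_letter_queue configured the batch is
        # retried until max_receive_count is exhausted
        self.processed += len(payload)

    def on_job_complete(self):
        # runs once, after all batches have been processed
        print(f"processed {self.processed} items")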

pkg/cortex/serve/start/batch.py

Lines changed: 146 additions & 80 deletions
@@ -21,6 +21,7 @@
 import threading
 import math
 import pathlib
+import uuid
 
 import boto3
 import botocore
@@ -34,7 +35,11 @@
 from cortex_internal.lib.exceptions import UserRuntimeException
 
 API_LIVENESS_UPDATE_PERIOD = 5  # seconds
-MAXIMUM_MESSAGE_VISIBILITY = 60 * 60 * 12  # 12 hours is the maximum message visibility
+SQS_POLL_WAIT_TIME = 10  # seconds
+MESSAGE_NOT_FOUND_SLEEP = 10  # seconds
+INITIAL_MESSAGE_VISIBILITY = 30  # seconds
+MESSAGE_RENEWAL_PERIOD = 15  # seconds
+JOB_COMPLETE_MESSAGE_RENEWAL = 10  # seconds
 
 local_cache = {
     "api_spec": None,
@@ -48,6 +53,10 @@
 }
 
 
+receipt_handle_mutex = threading.Lock()
+stop_renewal = set()
+
+
 def dimensions():
     return [
         {"Name": "APIName", "Value": local_cache["api_spec"].name},
@@ -67,6 +76,40 @@ def time_per_batch_metric(total_time_seconds):
     return {"MetricName": "TimePerBatch", "Dimensions": dimensions(), "Value": total_time_seconds}
 
 
+def renew_message_visibility(receipt_handle: str):
+    queue_url = local_cache["job_spec"]["sqs_url"]
+    interval = MESSAGE_RENEWAL_PERIOD
+    new_timeout = INITIAL_MESSAGE_VISIBILITY
+    cur_time = time.time()
+
+    while True:
+        time.sleep((cur_time + interval) - time.time())
+        cur_time += interval
+        new_timeout += interval
+
+        with receipt_handle_mutex:
+            if receipt_handle in stop_renewal:
+                stop_renewal.remove(receipt_handle)
+                break
+
+            try:
+                local_cache["sqs_client"].change_message_visibility(
+                    QueueUrl=queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=new_timeout
+                )
+            except botocore.exceptions.ClientError as e:
+                if e.response["Error"]["Code"] == "InvalidParameterValue":
+                    # unexpected; this error is thrown when attempting to renew a message that has been deleted
+                    continue
+                elif e.response["Error"]["Code"] == "AWS.SimpleQueueService.NonExistentQueue":
+                    # there may be a delay between the cron deleting the queue and this worker stopping
+                    logger().info(
+                        "failed to renew message visibility because the queue was not found"
+                    )
+                else:
+                    stop_renewal.remove(receipt_handle)
+                    raise e
+
+
 def build_predict_args(payload, batch_id):
     args = {}
 
@@ -102,49 +145,6 @@ def get_total_messages_in_queue():
     return visible_count, not_visible_count
 
 
-def handle_on_complete(message):
-    job_spec = local_cache["job_spec"]
-    predictor_impl = local_cache["predictor_impl"]
-    sqs_client = local_cache["sqs_client"]
-    queue_url = job_spec["sqs_url"]
-    receipt_handle = message["ReceiptHandle"]
-
-    try:
-        if not getattr(predictor_impl, "on_job_complete", None):
-            sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
-            return True
-
-        should_run_on_job_complete = False
-
-        while True:
-            visible_count, not_visible_count = get_total_messages_in_queue()
-
-            # if there are other messages that are visible, release this message and get the other ones (should rarely happen for FIFO)
-            if visible_count > 0:
-                sqs_client.change_message_visibility(
-                    QueueUrl=queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=0
-                )
-                return False
-
-            if should_run_on_job_complete:
-                # double check that the queue is still empty (except for the job_complete message)
-                if not_visible_count <= 1:
-                    logger().info("executing on_job_complete")
-                    predictor_impl.on_job_complete()
-                    sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
-                    return True
-                else:
-                    should_run_on_job_complete = False
-
-            if not_visible_count <= 1:
-                should_run_on_job_complete = True
-
-            time.sleep(20)
-    except:
-        sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
-        raise
-
-
 def sqs_loop():
     job_spec = local_cache["job_spec"]
     api_spec = local_cache["api_spec"]
@@ -159,52 +159,118 @@ def sqs_loop():
         response = sqs_client.receive_message(
             QueueUrl=queue_url,
             MaxNumberOfMessages=1,
-            WaitTimeSeconds=10,
-            VisibilityTimeout=MAXIMUM_MESSAGE_VISIBILITY,
+            WaitTimeSeconds=SQS_POLL_WAIT_TIME,
+            VisibilityTimeout=INITIAL_MESSAGE_VISIBILITY,
             MessageAttributeNames=["All"],
         )
 
         if response.get("Messages") is None or len(response["Messages"]) == 0:
-            if no_messages_found_in_previous_iteration:
-                logger().info("no batches left in queue, exiting...")
-                return
-            else:
+            visible_messages, invisible_messages = get_total_messages_in_queue()
+            if visible_messages + invisible_messages == 0:
+                if no_messages_found_in_previous_iteration:
+                    logger().info("no batches left in queue, exiting...")
+                    return
                 no_messages_found_in_previous_iteration = True
-                continue
-        else:
-            no_messages_found_in_previous_iteration = False
 
-        message = response["Messages"][0]
+            time.sleep(MESSAGE_NOT_FOUND_SLEEP)
+            continue
 
+        no_messages_found_in_previous_iteration = False
+        message = response["Messages"][0]
         receipt_handle = message["ReceiptHandle"]
 
-        if "MessageAttributes" in message and "job_complete" in message["MessageAttributes"]:
-            handled_on_complete = handle_on_complete(message)
-            if handled_on_complete:
-                logger().info("no batches left in queue, job has been completed")
-                return
+        renewer = threading.Thread(
+            target=renew_message_visibility, args=(receipt_handle,), daemon=True
+        )
+        renewer.start()
+
+        if is_on_job_complete(message):
+            handle_on_job_complete(message)
+        else:
+            handle_batch_message(message)
+
+
+def is_on_job_complete(message) -> bool:
+    return "MessageAttributes" in message and "job_complete" in message["MessageAttributes"]
+
+
+def handle_batch_message(message):
+    job_spec = local_cache["job_spec"]
+    predictor_impl = local_cache["predictor_impl"]
+    sqs_client = local_cache["sqs_client"]
+    queue_url = job_spec["sqs_url"]
+    receipt_handle = message["ReceiptHandle"]
+    api_spec = local_cache["api_spec"]
+
+    start_time = time.time()
+
+    try:
+        logger().info(f"processing batch {message['MessageId']}")
+        payload = json.loads(message["Body"])
+        batch_id = message["MessageId"]
+        predictor_impl.predict(**build_predict_args(payload, batch_id))
+
+        api_spec.post_metrics(
+            [success_counter_metric(), time_per_batch_metric(time.time() - start_time)]
+        )
+    except:
+        api_spec.post_metrics([failed_counter_metric()])
+        logger().exception(f"failed processing batch {message['MessageId']}")
+        with receipt_handle_mutex:
+            stop_renewal.add(receipt_handle)
+            if job_spec.get("sqs_dead_letter_queue") is not None:
+                # return the message to the queue; once it has been received
+                # max_receive_count times, SQS redirects it to the dead letter queue
+                sqs_client.change_message_visibility(
+                    QueueUrl=queue_url, ReceiptHandle=receipt_handle, VisibilityTimeout=0
+                )
+            else:
+                sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
+    else:
+        with receipt_handle_mutex:
+            stop_renewal.add(receipt_handle)
+            sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
+
+
+def handle_on_job_complete(message):
+    job_spec = local_cache["job_spec"]
+    predictor_impl = local_cache["predictor_impl"]
+    sqs_client = local_cache["sqs_client"]
+    queue_url = job_spec["sqs_url"]
+    receipt_handle = message["ReceiptHandle"]
+
+    should_run_on_job_complete = False
+    try:
+        while True:
+            visible_messages, invisible_messages = get_total_messages_in_queue()
+            total_messages = visible_messages + invisible_messages
+            if total_messages > 1:
+                # other batches are still in flight; re-enqueue the job_complete
+                # message so a worker picks it up again later
+                new_message_id = uuid.uuid4()
+                time.sleep(JOB_COMPLETE_MESSAGE_RENEWAL)
+                sqs_client.send_message(
+                    QueueUrl=queue_url,
+                    MessageBody='"job_complete"',
+                    MessageAttributes={
+                        "job_complete": {"StringValue": "true", "DataType": "String"},
+                        "api_name": {"StringValue": job_spec["api_name"], "DataType": "String"},
+                        "job_id": {"StringValue": job_spec["job_id"], "DataType": "String"},
+                    },
+                    MessageDeduplicationId=str(new_message_id),
+                    MessageGroupId=str(new_message_id),
+                )
+                break
             else:
-                # sometimes on_job_complete message will be released if there are other messages still to be processed
-                continue
-
-        try:
-            logger().info(f"processing batch {message['MessageId']}")
-
-            start_time = time.time()
-
-            payload = json.loads(message["Body"])
-            batch_id = message["MessageId"]
-            predictor_impl.predict(**build_predict_args(payload, batch_id))
-
-            api_spec.post_metrics(
-                [success_counter_metric(), time_per_batch_metric(time.time() - start_time)]
-            )
-        except Exception:
-            api_spec.post_metrics(
-                [failed_counter_metric(), time_per_batch_metric(time.time() - start_time)]
-            )
-            logger().exception("failed to process batch")
-        finally:
+                if should_run_on_job_complete:
+                    if getattr(predictor_impl, "on_job_complete", None):
+                        logger().info("executing on_job_complete")
+                        predictor_impl.on_job_complete()
+                    break
+                should_run_on_job_complete = True
+            time.sleep(10)  # verify that the queue is empty one more time
+    except:
+        logger().exception("failed to handle on_job_complete")
+        raise
+    finally:
+        with receipt_handle_mutex:
+            stop_renewal.add(receipt_handle)
             sqs_client.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
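
A note on the mechanism: on a failed attempt the worker returns the message to the queue with VisibilityTimeout=0 instead of deleting it, so the job queue's redrive policy (configured from max_receive_count) moves the batch to the dead letter queue once it has been received that many times. The dead letter queue must exist before the job is submitted; below is a minimal boto3 sketch of creating one. The queue name and region are placeholders, and this code is not part of the commit:

import boto3

sqs = boto3.client("sqs", region_name="us-west-2")

# Cortex job queues are FIFO (note the MessageDeduplicationId/MessageGroupId
# usage above), and AWS requires a FIFO queue's dead letter queue to also be
# FIFO, so the name must end in ".fifo"
dlq_url = sqs.create_queue(
    QueueName="failed.fifo",
    Attributes={"FifoQueue": "true"},
)["QueueUrl"]

# this ARN is what the job submission's "sqs_dead_letter_queue.arn" expects
dlq_arn = sqs.get_queue_attributes(
    QueueUrl=dlq_url,
    AttributeNames=["QueueArn"],
)["Attributes"]["QueueArn"]

print(dlq_arn)  # e.g. arn:aws:sqs:us-west-2:123456789:failed.fifo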