|
5 | 5 | import os |
6 | 6 | from typing import Tuple |
7 | 7 |
|
| 8 | +from google.api_core import retry |
8 | 9 | import requests |
9 | 10 | import requests.exceptions |
10 | 11 |
|
@@ -60,6 +61,8 @@ def __init__(self, |
60 | 61 | 'Authorization': 'Bearer %s' % api_key |
61 | 62 | } |
62 | 63 |
|
| 64 | + @retry.Retry(predicate=retry.if_exception_type( |
| 65 | + labelbox.exceptions.InternalServerError)) |
63 | 66 | def execute(self, query, params=None, timeout=10.0): |
64 | 67 | """ Sends a request to the server for the execution of the |
65 | 68 | given query. Checks the response for errors and wraps errors |
@@ -123,6 +126,9 @@ def convert_value(value): |
123 | 126 | try: |
124 | 127 | r_json = response.json() |
125 | 128 | except: |
| 129 | + error_502 = '502 Bad Gateway' |
| 130 | + if error_502 in response.text: |
| 131 | + raise labelbox.exceptions.InternalServerError(error_502) |
126 | 132 | raise labelbox.exceptions.LabelboxError( |
127 | 133 | "Failed to parse response as JSON: %s" % response.text) |
128 | 134 |
|
@@ -170,6 +176,12 @@ def check_errors(keywords, *path): |
170 | 176 | if response_msg.startswith("You have exceeded"): |
171 | 177 | raise labelbox.exceptions.ApiLimitError(response_msg) |
172 | 178 |
|
| 179 | + prisma_error = check_errors(["INTERNAL_SERVER_ERROR"], "extensions", |
| 180 | + "code") |
| 181 | + if prisma_error: |
| 182 | + raise labelbox.exceptions.InternalServerError( |
| 183 | + prisma_error["message"]) |
| 184 | + |
173 | 185 | if len(errors) > 0: |
174 | 186 | logger.warning("Unparsed errors on query execution: %r", errors) |
175 | 187 | raise labelbox.exceptions.LabelboxError("Unknown error: %s" % |
|
0 commit comments