|
5 | 5 | import os |
6 | 6 | from typing import Tuple |
7 | 7 |
|
| 8 | +from google.api_core import retry |
8 | 9 | import requests |
9 | 10 | import requests.exceptions |
10 | 11 |
|
@@ -60,6 +61,8 @@ def __init__(self, |
60 | 61 | 'Authorization': 'Bearer %s' % api_key |
61 | 62 | } |
62 | 63 |
|
| 64 | + @retry.Retry(predicate=retry.if_exception_type( |
| 65 | + labelbox.exceptions.InternalServerError)) |
63 | 66 | def execute(self, query, params=None, timeout=10.0): |
64 | 67 | """ Sends a request to the server for the execution of the |
65 | 68 | given query. Checks the response for errors and wraps errors |
@@ -121,12 +124,15 @@ def convert_value(value): |
121 | 124 | "Unknown error during Client.query(): " + str(e), e) |
122 | 125 |
|
123 | 126 | try: |
124 | | - response = response.json() |
| 127 | + r_json = response.json() |
125 | 128 | except: |
| 129 | + error_502 = '502 Bad Gateway' |
| 130 | + if error_502 in response.text: |
| 131 | + raise labelbox.exceptions.InternalServerError(error_502) |
126 | 132 | raise labelbox.exceptions.LabelboxError( |
127 | 133 | "Failed to parse response as JSON: %s" % response.text) |
128 | 134 |
|
129 | | - errors = response.get("errors", []) |
| 135 | + errors = r_json.get("errors", []) |
130 | 136 |
|
131 | 137 | def check_errors(keywords, *path): |
132 | 138 | """ Helper that looks for any of the given `keywords` in any of |
@@ -166,16 +172,32 @@ def check_errors(keywords, *path): |
166 | 172 | graphql_error["message"]) |
167 | 173 |
|
168 | 174 | # Check if API limit was exceeded |
169 | | - response_msg = response.get("message", "") |
| 175 | + response_msg = r_json.get("message", "") |
170 | 176 | if response_msg.startswith("You have exceeded"): |
171 | 177 | raise labelbox.exceptions.ApiLimitError(response_msg) |
172 | 178 |
|
| 179 | + prisma_error = check_errors(["INTERNAL_SERVER_ERROR"], "extensions", |
| 180 | + "code") |
| 181 | + if prisma_error: |
| 182 | + raise labelbox.exceptions.InternalServerError( |
| 183 | + prisma_error["message"]) |
| 184 | + |
173 | 185 | if len(errors) > 0: |
174 | 186 | logger.warning("Unparsed errors on query execution: %r", errors) |
175 | 187 | raise labelbox.exceptions.LabelboxError("Unknown error: %s" % |
176 | 188 | str(errors)) |
177 | 189 |
|
178 | | - return response["data"] |
| 190 | + # if we do return a proper error code, and didn't catch this above |
| 191 | + # reraise |
| 192 | + # this mainly catches a 401 for API access disabled for free tier |
| 193 | + # TODO: need to unify API errors to handle things more uniformly |
| 194 | + # in the SDK |
| 195 | + if response.status_code != requests.codes.ok: |
| 196 | + message = f"{response.status_code} {response.reason}" |
| 197 | + cause = r_json.get('message') |
| 198 | + raise labelbox.exceptions.LabelboxError(message, cause) |
| 199 | + |
| 200 | + return r_json["data"] |
179 | 201 |
|
180 | 202 | def upload_file(self, path: str) -> str: |
181 | 203 | """Uploads given path to local file. |
|
0 commit comments