Skip to content

Commit b012e39

Browse files
Grzegorz Szpak and rllin authored
Added BulkImportRequest integration (#27)
* Added method to create BulkImportRequest from dictionaries * Added method to upload local ndjson with predictions to Labelbox' GCS * Bugfix: field_type * Moved part of try block to else block * Removed UploadedFileType enum * Fix * Creating BulkImportRequest from url * Creating BulkImportRequest objects from objects and local file * Added ndjson validation + sending contentLength * Making relationships work for BulkImportRequest * Added tests for BulkImportRequests * Added test for BulkImportRequest.refresh() * Added docstrings * Updated changelog and setup.py * Vhanged test URL * Using existing URL in tests * Implemented BulkImportRequest.wait_till_done method * Actually sleeping * Bumped version to 2.4.3 * Yapfing the whole project * Made mypy happy * Made mypy happy one more time * freeze dependencies Co-authored-by: rllin <randall@labelbox.com>
1 parent ad6d439 commit b012e39

File tree

10 files changed

+500
-4
lines changed

10 files changed

+500
-4
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# Changelog
22

3+
## Version 2.4.3 (2020-08-04)
4+
5+
### Added
6+
* `BulkImportRequest` data type
7+
38
## Version 2.4.2 (2020-08-01)
49
### Fixed
510
* `Client.upload_data` will now pass the correct `content-length` when uploading data.

labelbox/orm/db_object.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def _set_field_values(self, field_values):
6464
logger.warning(
6565
"Failed to convert value '%s' to datetime for "
6666
"field %s", value, field)
67+
elif isinstance(field.field_type, Field.EnumType):
68+
value = field.field_type.enum_cls[value]
6769
setattr(self, field.name, value)
6870

6971
def __repr__(self):

labelbox/orm/model.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from enum import Enum, auto
2+
from typing import Union
23

34
from labelbox import utils
4-
from labelbox.exceptions import InvalidAttributeError, LabelboxError
5+
from labelbox.exceptions import InvalidAttributeError
56
from labelbox.orm.comparison import Comparison
67
""" Defines Field, Relationship and Entity. These classes are building
78
blocks for defining the Labelbox schema, DB object operations and
@@ -42,6 +43,15 @@ class Type(Enum):
4243
ID = auto()
4344
DateTime = auto()
4445

46+
class EnumType:
47+
48+
def __init__(self, enum_cls: type):
49+
self.enum_cls = enum_cls
50+
51+
@property
52+
def name(self):
53+
return self.enum_cls.__name__
54+
4555
class Order(Enum):
4656
""" Type of sort ordering. """
4757
Asc = auto()
@@ -71,7 +81,14 @@ def ID(*args):
7181
def DateTime(*args):
7282
return Field(Field.Type.DateTime, *args)
7383

74-
def __init__(self, field_type, name, graphql_name=None):
84+
@staticmethod
85+
def Enum(enum_cls: type, *args):
86+
return Field(Field.EnumType(enum_cls), *args)
87+
88+
def __init__(self,
89+
field_type: Union[Type, EnumType],
90+
name,
91+
graphql_name=None):
7592
""" Field init.
7693
Args:
7794
field_type (Field.Type): The type of the field.
Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
import json
2+
import logging
3+
import time
4+
from pathlib import Path
5+
from typing import BinaryIO
6+
from typing import Iterable
7+
from typing import Tuple
8+
from typing import Union
9+
10+
import backoff
11+
import ndjson
12+
import requests
13+
14+
import labelbox.exceptions
15+
from labelbox import Client
16+
from labelbox import Project
17+
from labelbox import User
18+
from labelbox.orm import query
19+
from labelbox.orm.db_object import DbObject
20+
from labelbox.orm.model import Field
21+
from labelbox.orm.model import Relationship
22+
from labelbox.schema.enums import BulkImportRequestState
23+
24+
NDJSON_MIME_TYPE = "application/x-ndjson"
25+
logger = logging.getLogger(__name__)
26+
27+
28+
class BulkImportRequest(DbObject):
    """ Represents the import job of a batch of ndjson-formatted predictions
    into a Labelbox project.

    Create one with `create_from_url`, `create_from_objects` or
    `create_from_local_file`, then track its progress with `refresh` or
    `wait_until_done`.
    """
    project = Relationship.ToOne("Project")
    name = Field.String("name")
    created_at = Field.DateTime("created_at")
    created_by = Relationship.ToOne("User", False, "created_by")
    input_file_url = Field.String("input_file_url")
    error_file_url = Field.String("error_file_url")
    status_file_url = Field.String("status_file_url")
    state = Field.Enum(BulkImportRequestState, "state")

    @classmethod
    def create_from_url(cls, client: Client, project_id: str, name: str,
                        url: str) -> 'BulkImportRequest':
        """
        Creates a BulkImportRequest from a publicly accessible URL
        to an ndjson file with predictions.

        Args:
            client (Client): a Labelbox client
            project_id (str): id of project for which predictions will be imported
            name (str): name of BulkImportRequest
            url (str): publicly accessible URL pointing to ndjson file containing predictions
        Returns:
            BulkImportRequest object
        """
        query_str = """mutation createBulkImportRequestPyApi(
                $projectId: ID!, $name: String!, $fileUrl: String!) {
            createBulkImportRequest(data: {
                projectId: $projectId,
                name: $name,
                fileUrl: $fileUrl
            }) {
                %s
            }
        }
        """ % cls.__build_results_query_part()
        params = {"projectId": project_id, "name": name, "fileUrl": url}
        bulk_import_request_response = client.execute(query_str, params=params)
        return cls.__build_bulk_import_request_from_result(
            client, bulk_import_request_response["createBulkImportRequest"])

    @classmethod
    def create_from_objects(cls, client: Client, project_id: str, name: str,
                            predictions: Iterable[dict]) -> 'BulkImportRequest':
        """
        Creates a BulkImportRequest from an iterable of dictionaries conforming to
        JSON predictions format, e.g.:
        ``{
            "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092",
            "schemaId": "ckappz7d700gn0zbocmqkwd9i",
            "dataRow": {
                "id": "ck1s02fqxm8fi0757f0e6qtdc"
            },
            "bbox": {
                "top": 48,
                "left": 58,
                "height": 865,
                "width": 1512
            }
        }``

        Args:
            client (Client): a Labelbox client
            project_id (str): id of project for which predictions will be imported
            name (str): name of BulkImportRequest
            predictions (Iterable[dict]): iterable of dictionaries representing predictions
        Returns:
            BulkImportRequest object
        """
        data_str = ndjson.dumps(predictions)
        data = data_str.encode('utf-8')
        file_name = cls.__make_file_name(project_id, name)
        # contentLength must describe the uploaded payload, so measure the
        # encoded bytes rather than the unicode string -- the two differ
        # whenever predictions contain non-ASCII characters.
        request_data = cls.__make_request_data(project_id, name, len(data),
                                               file_name)
        file_data = (file_name, data, NDJSON_MIME_TYPE)
        response_data = cls.__send_create_file_command(client, request_data,
                                                       file_name, file_data)
        return cls.__build_bulk_import_request_from_result(
            client, response_data["createBulkImportRequest"])

    @classmethod
    def create_from_local_file(cls,
                               client: Client,
                               project_id: str,
                               name: str,
                               file: Path,
                               validate_file=True) -> 'BulkImportRequest':
        """
        Creates a BulkImportRequest from a local ndjson file with predictions.

        Args:
            client (Client): a Labelbox client
            project_id (str): id of project for which predictions will be imported
            name (str): name of BulkImportRequest
            file (Path): local ndjson file with predictions
            validate_file (bool): a flag indicating if there should be a validation
                if `file` is a valid ndjson file
        Returns:
            BulkImportRequest object
        Raises:
            ValueError: if `validate_file` is True and `file` is not valid ndjson
        """
        file_name = cls.__make_file_name(project_id, name)
        content_length = file.stat().st_size
        request_data = cls.__make_request_data(project_id, name, content_length,
                                               file_name)
        with file.open('rb') as f:
            file_data: Tuple[str, Union[bytes, BinaryIO], str]
            if validate_file:
                # Validation requires the whole content in memory; upload the
                # already-read bytes instead of the (consumed) file object.
                data = f.read()
                try:
                    ndjson.loads(data)
                except ValueError as e:
                    raise ValueError(f"{file} is not a valid ndjson file") from e
                file_data = (file.name, data, NDJSON_MIME_TYPE)
            else:
                # Without validation, stream the file object directly.
                file_data = (file.name, f, NDJSON_MIME_TYPE)
            response_data = cls.__send_create_file_command(
                client, request_data, file_name, file_data)
        return cls.__build_bulk_import_request_from_result(
            client, response_data["createBulkImportRequest"])

    # TODO(gszpak): building query body should be handled by the client
    @classmethod
    def get(cls, client: Client, project_id: str,
            name: str) -> 'BulkImportRequest':
        """
        Fetches existing BulkImportRequest.

        Args:
            client (Client): a Labelbox client
            project_id (str): BulkImportRequest's project id
            name (str): name of BulkImportRequest
        Returns:
            BulkImportRequest object
        Raises:
            labelbox.exceptions.ResourceNotFoundError: if no request matches
        """
        query_str = """query getBulkImportRequestPyApi(
                $projectId: ID!, $name: String!) {
            bulkImportRequest(where: {
                projectId: $projectId,
                name: $name
            }) {
                %s
            }
        }
        """ % cls.__build_results_query_part()
        params = {"projectId": project_id, "name": name}
        bulk_import_request_kwargs = \
            client.execute(query_str, params=params).get("bulkImportRequest")
        if bulk_import_request_kwargs is None:
            raise labelbox.exceptions.ResourceNotFoundError(
                BulkImportRequest, {
                    "projectId": project_id,
                    "name": name
                })
        return cls.__build_bulk_import_request_from_result(
            client, bulk_import_request_kwargs)

    def refresh(self) -> None:
        """
        Synchronizes values of all fields with the database.
        """
        bulk_import_request = self.get(self.client,
                                       self.project().uid, self.name)
        for field in self.fields():
            setattr(self, field.name, getattr(bulk_import_request, field.name))

    def wait_until_done(self, sleep_time_seconds: int = 30) -> None:
        """
        Blocks until the BulkImportRequest.state changes either to
        `BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`,
        periodically refreshing object's state.

        Args:
            sleep_time_seconds (int): a time to block between subsequent API calls
        """
        while self.state == BulkImportRequestState.RUNNING:
            # Lazy %-style args so the message is only formatted when emitted.
            logger.info("Sleeping for %s seconds...", sleep_time_seconds)
            time.sleep(sleep_time_seconds)
            self.__exponential_backoff_refresh()

    @backoff.on_exception(
        backoff.expo,
        (labelbox.exceptions.ApiLimitError, labelbox.exceptions.TimeoutError,
         labelbox.exceptions.NetworkError),
        max_tries=10,
        jitter=None)
    def __exponential_backoff_refresh(self) -> None:
        # Retries transient API failures with exponential backoff.
        self.refresh()

    # TODO(gszpak): project() and created_by() methods
    # TODO(gszpak): are hacky ways to eagerly load the relationships
    def project(self):  # type: ignore
        """ Returns the eagerly-loaded Project, or None if absent. """
        # The private attribute is only assigned when the API response
        # contained a non-null project; guard against it never being set,
        # which would otherwise raise AttributeError instead of returning None.
        try:
            return self.__project
        except AttributeError:
            return None

    def created_by(self):  # type: ignore
        """ Returns the eagerly-loaded creator User, or None if absent. """
        try:
            return self.__user
        except AttributeError:
            return None

    @classmethod
    def __make_file_name(cls, project_id: str, name: str) -> str:
        # Deterministic upload name, unique per (project, request name) pair.
        return f"{project_id}__{name}.ndjson"

    # TODO(gszpak): move it to client.py
    @classmethod
    def __make_request_data(cls, project_id: str, name: str,
                            content_length: int, file_name: str) -> dict:
        # Builds the GraphQL multipart request body ("operations" + "map"
        # form fields); the "file" variable is null and filled in by the
        # server from the mapped multipart file part.
        query_str = """mutation createBulkImportRequestFromFilePyApi(
                $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
            createBulkImportRequest(data: {
                projectId: $projectId,
                name: $name,
                filePayload: {
                    file: $file,
                    contentLength: $contentLength
                }
            }) {
                %s
            }
        }
        """ % cls.__build_results_query_part()
        variables = {
            "projectId": project_id,
            "name": name,
            "file": None,
            "contentLength": content_length
        }
        operations = json.dumps({"variables": variables, "query": query_str})

        return {
            "operations": operations,
            "map": (None, json.dumps({file_name: ["variables.file"]}))
        }

    # TODO(gszpak): move it to client.py
    @classmethod
    def __send_create_file_command(
            cls, client: Client, request_data: dict, file_name: str,
            file_data: Tuple[str, Union[bytes, BinaryIO], str]) -> dict:
        # Posts the multipart upload directly (the client's execute() does
        # not support file uploads).
        response = requests.post(
            client.endpoint,
            headers={"authorization": "Bearer %s" % client.api_key},
            data=request_data,
            files={file_name: file_data})

        try:
            response_json = response.json()
        except ValueError:
            raise labelbox.exceptions.LabelboxError(
                "Failed to parse response as JSON: %s" % response.text)

        response_data = response_json.get("data", None)
        if response_data is None:
            raise labelbox.exceptions.LabelboxError(
                "Failed to upload, message: %s" %
                response_json.get("errors", None))

        if not response_data.get("createBulkImportRequest", None):
            # Parenthesize the fallback: `%` binds tighter than `or`, so
            # without parentheses the second error source was dead code.
            raise labelbox.exceptions.LabelboxError(
                "Failed to create BulkImportRequest, message: %s" %
                (response_json.get("errors", None) or
                 response_data.get("error", None)))

        return response_data

    # TODO(gszpak): all the code below should be handled automatically by Relationship
    @classmethod
    def __build_results_query_part(cls) -> str:
        # Selection set fetching the request plus its project and creator.
        return """
            project {
                %s
            }
            createdBy {
                %s
            }
            %s
        """ % (query.results_query_part(Project),
               query.results_query_part(User),
               query.results_query_part(BulkImportRequest))

    @classmethod
    def __build_bulk_import_request_from_result(
            cls, client: Client, result: dict) -> 'BulkImportRequest':
        # Pops the relationship payloads so the remaining dict matches the
        # BulkImportRequest fields, then caches the eagerly-loaded objects.
        project = result.pop("project")
        user = result.pop("createdBy")
        bulk_import_request = BulkImportRequest(client, result)
        if project is not None:
            bulk_import_request.__project = Project(  # type: ignore
                client, project)
        if user is not None:
            bulk_import_request.__user = User(client, user)  # type: ignore
        return bulk_import_request

labelbox/schema/enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from enum import Enum
2+
3+
4+
class BulkImportRequestState(Enum):
    """ State of a BulkImportRequest import job.

    RUNNING: the import is still being processed.
    FAILED: the import terminated with an error.
    FINISHED: the import completed successfully.
    """
    RUNNING = "RUNNING"
    FAILED = "FAILED"
    FINISHED = "FINISHED"

mypy.ini

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[mypy-backoff.*]
2+
ignore_missing_imports = True
3+
4+
[mypy-ndjson.*]
5+
ignore_missing_imports = True

0 commit comments

Comments
 (0)