Skip to content

Commit 6b147ff

Browse files
authored
Merge pull request #593 from Labelbox/ms/fetch-result
improve result fetching experience
2 parents 5c498a6 + fa5784e commit 6b147ff

File tree

4 files changed

+89
-14
lines changed

4 files changed

+89
-14
lines changed

labelbox/schema/project.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _validate_datetime(string_date: str) -> bool:
327327
try:
328328
datetime.strptime(string_date, "%Y-%m-%d")
329329
except ValueError:
330-
raise ValueError(f"""Incorrect format for: {string_date}.
330+
raise ValueError(f"""Incorrect format for: {string_date}.
331331
Format must be \"YYYY-MM-DD\"""")
332332
return True
333333

labelbox/schema/task.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
2+
import json
23
import requests
34
import time
4-
from typing import TYPE_CHECKING, Optional
5+
from typing import TYPE_CHECKING, TypeVar, Callable, Optional, Dict, Any, List
56

67
from labelbox.exceptions import ResourceNotFoundError
78
from labelbox.orm.db_object import DbObject
@@ -10,6 +11,11 @@
1011
if TYPE_CHECKING:
1112
from labelbox import User
1213

14+
def lru_cache() -> Callable[..., Callable[..., Dict[str, Any]]]:
15+
pass
16+
else:
17+
from functools import lru_cache
18+
1319
logger = logging.getLogger(__name__)
1420

1521

@@ -32,7 +38,7 @@ class Task(DbObject):
3238
name = Field.String("name")
3339
status = Field.String("status")
3440
completion_percentage = Field.Float("completion_percentage")
35-
result = Field.String("result")
41+
result_url = Field.String("result_url", "result")
3642
_user: Optional["User"] = None
3743

3844
# Relationships
@@ -68,12 +74,44 @@ def wait_till_done(self, timeout_seconds=300) -> None:
6874
time.sleep(sleep_time_seconds)
6975
self.refresh()
7076

71-
def errors(self):
77+
@property
78+
def errors(self) -> Optional[Dict[str, Any]]:
7279
""" Downloads the result file from Task
7380
"""
74-
if self.status == "FAILED" and self.result:
75-
response = requests.get(self.result)
76-
response.raise_for_status()
77-
data = response.json()
78-
return data.get('error')
81+
if self.status == "FAILED":
82+
result = self._fetch_remote_json()
83+
return result['error']
7984
return None
85+
86+
@property
87+
def result(self) -> List[Dict[str, Any]]:
88+
""" Fetch the result for a task
89+
"""
90+
if self.status == "FAILED":
91+
raise ValueError(f"Job failed. Errors : {self.errors}")
92+
else:
93+
result = self._fetch_remote_json()
94+
return [{
95+
'id': data_row['id'],
96+
'external_id': data_row.get('externalId'),
97+
'row_data': data_row['rowData']
98+
} for data_row in result['createdDataRows']]
99+
100+
@lru_cache()
101+
def _fetch_remote_json(self) -> Dict[str, Any]:
102+
""" Function for fetching and caching the result data.
103+
"""
104+
if self.name != 'JSON Import':
105+
raise ValueError(
106+
"Task result is only supported for `JSON Import` tasks."
107+
" Download task.result_url manually to access the result for other tasks."
108+
)
109+
self.wait_till_done(timeout_seconds=600)
110+
if self.status == "IN_PROGRESS":
111+
raise ValueError(
112+
"Job status still in `IN_PROGRESS`. The result is not available. Call task.wait_till_done() with a larger timeout or contact support."
113+
)
114+
115+
response = requests.get(self.result_url)
116+
response.raise_for_status()
117+
return response.json()

tests/integration/test_data_rows.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,3 +543,19 @@ def test_delete_data_row_attachment(datarow, image_url):
543543
attachment.delete()
544544

545545
assert len(list(datarow.attachments())) == 0
546+
547+
548+
def test_create_data_rows_result(client, dataset, image_url):
549+
task = dataset.create_data_rows([
550+
{
551+
DataRow.row_data: image_url,
552+
DataRow.external_id: "row1",
553+
},
554+
{
555+
DataRow.row_data: image_url,
556+
DataRow.external_id: "row1",
557+
},
558+
])
559+
assert task.errors is None
560+
for result in task.result:
561+
client.get_data_row(result['id'])

tests/integration/test_task.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from labelbox import DataRow
1+
import pytest
2+
3+
from labelbox import DataRow, Task
24
from labelbox.schema.data_row_metadata import DataRowMetadataField
35

46
EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
@@ -22,11 +24,14 @@ def test_task_errors(dataset, image_url):
2224
assert task in client.get_user().created_tasks()
2325
task.wait_till_done()
2426
assert task.status == "FAILED"
25-
assert task.errors() is not None
26-
assert 'message' in task.errors()
27+
assert task.errors is not None
28+
assert 'message' in task.errors
29+
with pytest.raises(Exception) as exc_info:
30+
task.result
31+
assert str(exc_info.value).startswith("Job failed. Errors : {")
2732

2833

29-
def test_task_success(dataset, image_url):
34+
def test_task_success_json(dataset, image_url):
3035
client = dataset.client
3136
task = dataset.create_data_rows([
3237
{
@@ -36,4 +41,20 @@ def test_task_success(dataset, image_url):
3641
assert task in client.get_user().created_tasks()
3742
task.wait_till_done()
3843
assert task.status == "COMPLETE"
39-
assert task.errors() is None
44+
assert task.errors is None
45+
assert task.result is not None
46+
assert len(task.result)
47+
48+
49+
def test_task_success_label_export(client, configured_project_with_label):
50+
project, _, _, _ = configured_project_with_label
51+
project.export_labels()
52+
user = client.get_user()
53+
task = None
54+
for task in user.created_tasks():
55+
if task.name != 'JSON Import':
56+
break
57+
58+
with pytest.raises(ValueError) as exc_info:
59+
task.result
60+
assert str(exc_info.value).startswith("Task result is only supported for")

0 commit comments

Comments
 (0)