Skip to content

Commit dbadecd

Browse files
authored
Merge pull request #840 from Labelbox/mno/al-4398
[AL-4398] Add Model Run exports
2 parents 80f07e3 + c9d8265 commit dbadecd

File tree

4 files changed

+137
-0
lines changed

4 files changed

+137
-0
lines changed

labelbox/schema/export_params.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import sys
2+
3+
from typing import Optional
4+
if sys.version_info >= (3, 8):
5+
from typing import TypedDict
6+
else:
7+
from typing_extensions import TypedDict
8+
9+
10+
class DataRowParams(TypedDict, total=False):
    """Flags selecting which data-row sections an export should include.

    Declared with ``total=False`` so that every key is optional: callers
    pass only the flags they care about (for example
    ``{"media_attributes": True}``) and consumers read the mapping with
    ``.get(key, False)``, treating a missing key as "not requested".
    """
    # Request the data row details section.
    data_row_details: Optional[bool]
    # Request the media attributes section.
    media_attributes: Optional[bool]
    # Request the metadata fields section.
    metadata_fields: Optional[bool]
    # Request the attachments section.
    attachments: Optional[bool]
15+
16+
17+
class ProjectExportParams(DataRowParams, total=False):
    """Export flags for project exports.

    Extends ``DataRowParams`` with project-level sections. Declared with
    ``total=False`` so the keys added here are optional and callers may
    supply only the flags they want to set.
    """
    # Request the project details section.
    project_details: Optional[bool]
    # Request the label details section.
    label_details: Optional[bool]
    # Request the performance details section.
    performance_details: Optional[bool]
21+
22+
23+
class ModelRunExportParams(DataRowParams):
    """Export flags for model run exports.

    Currently adds no keys beyond those inherited from ``DataRowParams``.
    """
    # TODO: Add model run fields

labelbox/schema/model_run.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from labelbox.orm.query import results_query_part
1313
from labelbox.orm.model import Field, Relationship, Entity
1414
from labelbox.orm.db_object import DbObject, experimental
15+
from labelbox.schema.export_params import ModelRunExportParams
16+
from labelbox.schema.task import Task
17+
from labelbox.schema.user import User
1518

1619
if TYPE_CHECKING:
1720
from labelbox import MEAPredictionImport
@@ -446,6 +449,63 @@ def export_labels(
446449
self.uid)
447450
time.sleep(sleep_time)
448451

452+
"""
453+
Creates a model run export task with the given params and returns the task.
454+
455+
>>> export_task = export_v2("my_export_task", params={"media_attributes": True})
456+
457+
"""
458+
459+
def export_v2(self,
              task_name: str,
              params: Optional[ModelRunExportParams] = None) -> Task:
    """
    Creates a model run export task with the given params and returns the task.

    >>> export_task = model_run.export_v2("my_export_task", params={"media_attributes": True})

    Args:
        task_name (str): Name given to the export task.
        params (Optional[ModelRunExportParams]): Flags selecting which
            optional sections to include. Any flag that is not supplied
            (or ``params`` being ``None``) is sent as ``False``.
    Returns:
        Task: The export task created by the mutation, looked up among
            the tasks created by the current user.
    Raises:
        ResourceNotFoundError: If exactly one task with the returned id
            cannot be found among the current user's created tasks.
    """
    _params = params or {}
    mutation_name = "exportDataRowsInModelRun"
    create_task_query_str = """mutation exportDataRowsInModelRunPyApi($input: ExportDataRowsInModelRunInput!){
        %s(input: $input) {taskId} }
        """ % (mutation_name)
    # Kept under its own name so the caller-supplied ``params`` argument is
    # not rebound to the GraphQL variables payload.
    query_params = {
        "input": {
            "taskName": task_name,
            "filters": {
                "modelRunId": self.uid
            },
            "params": {
                "includeAttachments":
                    _params.get('attachments', False),
                "includeMediaAttributes":
                    _params.get('media_attributes', False),
                "includeMetadata":
                    _params.get('metadata_fields', False),
                "includeDataRowDetails":
                    _params.get('data_row_details', False),
                # Arguments locked based on execution context
                "includeProjectDetails":
                    False,
                "includeLabels":
                    False,
                "includePerformanceDetails":
                    False,
            },
        }
    }
    res = self.client.execute(
        create_task_query_str,
        query_params,
    )
    res = res[mutation_name]
    task_id = res["taskId"]
    user: User = self.client.get_user()
    tasks: List[Task] = list(
        user.created_tasks(where=Entity.Task.uid == task_id))
    if len(tasks) != 1:
        raise ResourceNotFoundError(Entity.Task, task_id)
    task: Task = tasks[0]
    # Cache user in a private variable as the relationship can't be
    # resolved due to server-side limitations (see Task.created_by)
    # for more info.
    task._user = user
    return task
508+
449509

450510
class ModelRunDataRow(DbObject):
451511
label_id = Field.String("label_id")

labelbox/schema/task.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def wait_till_done(self, timeout_seconds=300) -> None:
8181
def errors(self) -> Optional[Dict[str, Any]]:
8282
""" Fetch the error associated with an import task.
8383
"""
84+
# TODO: We should handle error messages for export v2 tasks in the future.
85+
if self.name != 'JSON Import':
86+
return None
8487
if self.status == "FAILED":
8588
result = self._fetch_remote_json()
8689
return result["error"]

tests/integration/annotation_import/test_model_run.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import json
12
import time
23
import os
34
import pytest
45

56
from collections import Counter
7+
8+
import requests
69
from labelbox import DataSplit, ModelRun
710

811

@@ -114,6 +117,52 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows):
114117
assert len(labels) == 3
115118

116119

120+
def test_model_run_export_v2(model_run_with_model_run_data_rows,
                             configured_project):
    """Run an export_v2 task on a model run and validate the downloaded rows.

    Checks that the task completes, that one result row is produced per
    project label, that no row reports errors, that the requested
    ``media_attributes`` section is present, and that every exported label
    and prediction id belongs to the project's labels.
    """
    model_run = model_run_with_model_run_data_rows
    export_name = "test_task"
    include_media_attributes = True

    task = model_run.export_v2(
        export_name, params={"media_attributes": include_media_attributes})
    assert task.name == export_name
    task.wait_till_done()
    assert task.status == "COMPLETE"

    # The result is read as newline-delimited JSON: one object per line.
    response = requests.get(task.result_url)
    response.raise_for_status()
    exported_rows = [json.loads(line) for line in response.text.splitlines()]

    project_label_ids = [label.uid for label in configured_project.labels()]
    known_ids = set(project_label_ids)

    assert len(exported_rows) == len(project_label_ids)
    for row in exported_rows:
        assert len(row['errors']) == 0

        # Check export param handling
        media = row.get('media_attributes')
        if include_media_attributes:
            assert media is not None
        else:
            assert media is None

        run_export = row['models'][model_run.model_id]['model_runs'][
            model_run.uid]
        exported_label_ids = {label['id'] for label in run_export['labels']}
        exported_prediction_ids = {
            prediction['id'] for prediction in run_export['predictions']
        }
        # Every id referenced by the export must be a known project label id.
        assert exported_label_ids <= known_ids
        assert exported_prediction_ids <= known_ids
164+
165+
117166
@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem",
118167
reason="does not work for onprem")
119168
def test_model_run_status(model_run_with_model_run_data_rows):

0 commit comments

Comments
 (0)