Skip to content

Commit 32c30f2

Browse files
committed
Update tests
1 parent dc7d985 commit 32c30f2

File tree

3 files changed

+63
-30
lines changed

3 files changed

+63
-30
lines changed

labelbox/schema/export_params.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,23 @@
11
import sys
22

33
from typing import Optional
4-
if sys.version_info >= (3, 8):
5-
from typing import TypedDict
6-
else:
7-
from typing_extensions import TypedDict
84

5+
from pydantic import BaseModel
96

10-
class DataRowParams(TypedDict):
11-
include_data_row_details: Optional[bool]
12-
include_media_attributes: Optional[bool]
13-
include_metadata_fields: Optional[bool]
14-
include_attachments: Optional[bool]
157

8+
class DataRowParams(BaseModel):
9+
include_data_row_details: Optional[bool] = None
10+
include_media_attributes: Optional[bool] = None
11+
include_metadata_fields: Optional[bool] = None
12+
include_attachments: Optional[bool] = None
1613

17-
class ProjectExportParams(DataRowParams):
18-
include_project_details: Optional[bool]
19-
include_label_details: Optional[bool]
20-
include_performance_details: Optional[bool]
2114

15+
class ProjectExportParams(BaseModel):
16+
include_project_details: Optional[bool] = None
17+
include_label_details: Optional[bool] = None
18+
include_performance_details: Optional[bool] = None
2219

23-
class ModelRunExportParams(DataRowParams):
20+
21+
class ModelRunExportParams(BaseModel):
2422
# TODO: Add model run fields
2523
pass

labelbox/schema/model_run.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from labelbox.orm.db_object import DbObject, experimental
1515
from labelbox.schema.export_params import ModelRunExportParams
1616
from labelbox.schema.task import Task
17-
from labelbox.schema.user import User # type: ignore
17+
from labelbox.schema.user import User
1818

1919
if TYPE_CHECKING:
2020
from labelbox import MEAPredictionImport
@@ -450,25 +450,24 @@ def export_labels(
450450
time.sleep(sleep_time)
451451

452452
"""
453-
Creates a model run export task with the given filter and returns the task.
453+
Creates a model run export task with the given params and returns the task.
454454
455-
>>> export_task = export_labels_v2("my_export_task", filter={"media_attributes": True})
455+
>>> export_task = export_v2("my_export_task", params={"media_attributes": True})
456456
457457
"""
458458

459-
def export_labels_v2(self, task_name: str,
460-
params: Optional[ModelRunExportParams]) -> Task:
459+
def export_v2(self, task_name: str,
460+
params: Optional[ModelRunExportParams]) -> Task:
461461
_params = params or {}
462-
mutation_name = "exportDataRows"
463-
create_task_query_str = """mutation exportDataRowsPyApi($input: ExportDataRowsInput!){
462+
mutation_name = "exportDataRowsInModelRun"
463+
create_task_query_str = """mutation exportDataRowsInModelRunPyApi($input: ExportDataRowsInModelRunInput!){
464464
%s(input: $input) {taskId} }
465465
""" % (mutation_name)
466466
params = {
467467
"input": {
468468
"taskName": task_name,
469469
"filters": {
470-
"modelRunIds": [self.uid],
471-
"projectIds": []
470+
"modelRunId": self.uid
472471
},
473472
"params": {
474473
"includeAttachments":
@@ -480,8 +479,6 @@ def export_labels_v2(self, task_name: str,
480479
"includeDataRowDetails":
481480
_params.get('include_data_row_details', False),
482481
# Arguments locked based on exectuion context
483-
"includeModelRuns":
484-
True,
485482
"includeProjectDetails":
486483
False,
487484
"includeLabels":

tests/integration/annotation_import/test_model_run.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import json
12
import time
23
import os
34
import pytest
45

56
from collections import Counter
7+
8+
import requests
69
from labelbox import DataSplit, ModelRun
710

811

@@ -114,15 +117,50 @@ def test_model_run_export_labels(model_run_with_model_run_data_rows):
114117
assert len(labels) == 3
115118

116119

117-
def test_model_run_export_labels_v2(model_run_with_model_run_data_rows):
120+
def test_model_run_export_v2(model_run_with_model_run_data_rows,
121+
configured_project):
118122
task_name = "test_task"
119-
params = {"media_attributes": True}
120-
task = model_run_with_model_run_data_rows.export_labels_v2(task_name,
121-
params=params)
123+
124+
media_attributes = True
125+
params = {"media_attributes": media_attributes}
126+
task = model_run_with_model_run_data_rows.export_v2(task_name,
127+
params=params)
122128
assert task.name == task_name
123129
task.wait_till_done()
124130
assert task.status == "COMPLETE"
125-
# TODO: Download result and check it
131+
132+
def download_result(result_url):
133+
response = requests.get(result_url)
134+
response.raise_for_status()
135+
data = [json.loads(line) for line in response.text.splitlines()]
136+
return data
137+
138+
task_results = download_result(task.result_url)
139+
140+
label_ids = [label.uid for label in configured_project.labels()]
141+
label_ids_set = set(label_ids)
142+
143+
assert len(task_results) == len(label_ids)
144+
for task_result in task_results:
145+
assert len(task_result['errors']) == 0
146+
# Check export param handling
147+
if media_attributes:
148+
assert 'media_attributes' in task_result and task_result[
149+
'media_attributes'] is not None
150+
else:
151+
assert 'media_attributes' not in task_result or task_result[
152+
'media_attributes'] is None
153+
model_run = task_result['models'][
154+
model_run_with_model_run_data_rows.model_id]['model_runs'][
155+
model_run_with_model_run_data_rows.uid]
156+
task_label_ids_set = set(
157+
map(lambda label: label['id'], model_run['labels']))
158+
task_prediction_ids_set = set(
159+
map(lambda prediction: prediction['id'], model_run['predictions']))
160+
for label_id in task_label_ids_set:
161+
assert label_id in label_ids_set
162+
for prediction_id in task_prediction_ids_set:
163+
assert prediction_id in label_ids_set
126164

127165

128166
@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem",

0 commit comments

Comments
 (0)