Commit 978cb46
Author: Val Brodsky
Parent: f1633c0

PR: added extra test to verify project labels in a model run

File tree: 3 files changed, +37 −14 lines

    labelbox/schema/model_run.py
    tests/integration/annotation_import/conftest.py
    tests/integration/annotation_import/test_mea_prediction_import.py

labelbox/schema/model_run.py

Lines changed: 9 additions & 8 deletions
@@ -55,9 +55,11 @@ def upsert_labels(self,
         Args:
             label_ids (list): label ids to insert
             project_id (string): project uuid, all project labels will be uploaded
+                Either label_ids OR project_id is required but NOT both
             timeout_seconds (float): Max waiting time, in seconds.
         Returns:
             ID of newly generated async task
+
         """
 
         use_label_ids = label_ids is not None and len(label_ids) > 0
@@ -71,13 +73,13 @@ def upsert_labels(self,
             raise ValueError("Must only one of label ids, project id")
 
         if use_label_ids:
-            return self._upsert_labels_by_label_ids(label_ids)
+            return self._upsert_labels_by_label_ids(label_ids, timeout_seconds)
         else:  # use_project_id
-            return self._upsert_labels_by_project_id(project_id)
+            return self._upsert_labels_by_project_id(project_id,
+                                                     timeout_seconds)
 
-    def _upsert_labels_by_label_ids(self,
-                                    label_ids: List[str],
-                                    timeout_seconds=3600):
+    def _upsert_labels_by_label_ids(self, label_ids: List[str],
+                                    timeout_seconds: int):
         mutation_name = 'createMEAModelRunLabelRegistrationTask'
         create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
             %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
@@ -99,9 +101,8 @@ def _upsert_labels_by_label_ids(self,
                 }})['MEALabelRegistrationTaskStatus'],
             timeout_seconds=timeout_seconds)
 
-    def _upsert_labels_by_project_id(self,
-                                     project_id: str,
-                                     timeout_seconds=3600):
+    def _upsert_labels_by_project_id(self, project_id: str,
+                                     timeout_seconds: int):
         mutation_name = 'createMEAModelRunProjectLabelRegistrationTask'
         create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) {
             %s(where : { modelRunId : $modelRunId, projectId: $projectId})}
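
Note on the change above: the caller-supplied `timeout_seconds` is now forwarded to both private helpers; previously each helper fell back to its own `timeout_seconds=3600` default and the value passed to `upsert_labels` was silently dropped. A minimal usage sketch (the IDs and the 600-second timeout are placeholder values, not from this commit):

    # Register specific labels in the model run; the timeout is now threaded
    # through to _upsert_labels_by_label_ids instead of being ignored.
    model_run.upsert_labels(label_ids=["<label-id-1>", "<label-id-2>"],
                            timeout_seconds=600)

    # Or register every label in a project. label_ids and project_id are
    # mutually exclusive; supplying both raises ValueError.
    model_run.upsert_labels(project_id="<project-uuid>", timeout_seconds=600)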

tests/integration/annotation_import/conftest.py

Lines changed: 2 additions & 3 deletions
@@ -441,9 +441,8 @@ def model_run_with_model_run_data_rows(client, configured_project,
 
 
 @pytest.fixture
-def model_run_with_model_run_all_project_data_rows(client, configured_project,
-                                                   model_run_predictions,
-                                                   model_run):
+def model_run_with_all_project_labels(client, configured_project,
+                                      model_run_predictions, model_run):
     configured_project.enable_model_assisted_labeling()
 
     upload_task = LabelImport.create_from_objects(
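
Only the opening lines of the renamed fixture are visible in this hunk. A sketch of the shape the full fixture plausibly takes, assuming it uploads `model_run_predictions` as project labels and then registers them all via the new `project_id` path (everything past `create_from_objects` is an assumption, not shown in this diff):

    @pytest.fixture
    def model_run_with_all_project_labels(client, configured_project,
                                          model_run_predictions, model_run):
        configured_project.enable_model_assisted_labeling()

        upload_task = LabelImport.create_from_objects(
            client, configured_project.uid, str(uuid.uuid4()),
            model_run_predictions)
        upload_task.wait_until_done()  # assumed: wait for the label import

        # Assumed: register all project labels, exercising the project_id
        # branch of upsert_labels changed in this commit.
        model_run.upsert_labels(project_id=configured_project.uid)
        yield model_run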

tests/integration/annotation_import/test_mea_prediction_import.py

Lines changed: 26 additions & 3 deletions
@@ -39,20 +39,43 @@ def test_create_from_objects(model_run_with_model_run_data_rows,
 
 
 def test_create_from_objects_all_project_labels(
-        model_run_with_model_run_all_project_data_rows, object_predictions,
+        model_run_with_all_project_labels, object_predictions,
         annotation_import_test_helpers):
     name = str(uuid.uuid4())
 
-    annotation_import = model_run_with_model_run_all_project_data_rows.add_predictions(
+    annotation_import = model_run_with_all_project_labels.add_predictions(
         name=name, predictions=object_predictions)
 
-    assert annotation_import.model_run_id == model_run_with_model_run_all_project_data_rows.uid
+    assert annotation_import.model_run_id == model_run_with_all_project_labels.uid
     annotation_import_test_helpers.check_running_state(annotation_import, name)
     annotation_import_test_helpers.assert_file_content(
         annotation_import.input_file_url, object_predictions)
     annotation_import.wait_until_done()
 
 
+def test_model_run_project_labels(model_run_with_all_project_labels,
+                                  model_run_predictions):
+    model_run = model_run_with_all_project_labels
+    model_run_exported_labels = model_run.export_labels(download=True)
+    labels_indexed_by_schema_id = {}
+    for label in model_run_exported_labels:
+        # assuming the exported 'objects' array has only one label per data row, as is usual when there are no label revisions
+        schema_id = label['Label']['objects'][0]['schemaId']
+        labels_indexed_by_schema_id[schema_id] = label
+
+    assert (len(
+        labels_indexed_by_schema_id.keys())) == len(model_run_predictions)
+
+    # make sure the labels in this model run are exactly the labels uploaded to the project,
+    # by comparing some 'immutable' attributes
+    for expected_label in model_run_predictions:
+        schema_id = expected_label['schemaId']
+        actual_label = labels_indexed_by_schema_id[schema_id]
+        assert actual_label['Label']['objects'][0]['title'] == expected_label[
+            'name']
+        assert actual_label['DataRow ID'] == expected_label['dataRow']['id']
+
+
 def test_create_from_label_objects(model_run_with_model_run_data_rows,
                                    object_predictions,
                                    annotation_import_test_helpers):
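
For reference, the new test assumes each record returned by `export_labels(download=True)` has roughly the following shape; the keys shown are exactly the ones the test touches, while the values are illustrative placeholders:

    exported_label = {
        'DataRow ID': '<data-row-id>',  # compared to expected_label['dataRow']['id']
        'Label': {
            'objects': [{
                'schemaId': '<schema-id>',  # used as the index key
                'title': '<tool name>',     # compared to expected_label['name']
                # ...remaining annotation fields elided
            }],
        },
        # ...other top-level export fields elided
    }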
