Skip to content

Commit 3e701a4

Browse files
authored
Merge pull request #959 from Labelbox/VB/upsert-labels-from-project_AL-2960
Add logic to upload labels by project id to model run
2 parents fd45975 + 978cb46 commit 3e701a4

File tree

3 files changed

+103
-5
lines changed

3 files changed

+103
-5
lines changed

labelbox/schema/model_run.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,43 @@ class Status(Enum):
4747
COMPLETE = "COMPLETE"
4848
FAILED = "FAILED"
4949

50-
def upsert_labels(self,
                  label_ids: Optional[List[str]] = None,
                  project_id: Optional[str] = None,
                  timeout_seconds=3600):
    """ Adds data rows and labels to a Model Run.

    Exactly one of `label_ids` or `project_id` must be provided; passing
    neither, or both, raises `ValueError`.

    Args:
        label_ids (list): label ids to insert
        project_id (string): project uuid, all project labels will be uploaded
            Either label_ids OR project_id is required but NOT both
        timeout_seconds (float): Max waiting time, in seconds.
    Returns:
        ID of newly generated async task
    Raises:
        ValueError: if neither, or both, of `label_ids` and `project_id`
            are supplied (an empty `label_ids` list counts as not supplied).
    """
    # An explicitly passed empty list is treated the same as None.
    use_label_ids = label_ids is not None and len(label_ids) > 0
    use_project_id = project_id is not None

    if not use_label_ids and not use_project_id:
        raise ValueError(
            "Must provide at least one label id or a project id")

    if use_label_ids and use_project_id:
        # FIX: original message was ungrammatical
        # ("Must only one of label ids, project id").
        raise ValueError(
            "Must provide only one of label ids or project id")

    if use_label_ids:
        return self._upsert_labels_by_label_ids(label_ids, timeout_seconds)
    else:  # use_project_id
        return self._upsert_labels_by_project_id(project_id,
                                                 timeout_seconds)
6180

81+
def _upsert_labels_by_label_ids(self, label_ids: List[str],
82+
timeout_seconds: int):
6283
mutation_name = 'createMEAModelRunLabelRegistrationTask'
6384
create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
64-
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
65-
""" % (mutation_name)
85+
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
86+
""" % (mutation_name)
6687

6788
res = self.client.execute(create_task_query_str, {
6889
'modelRunId': self.uid,
@@ -80,6 +101,29 @@ def upsert_labels(self, label_ids, timeout_seconds=3600):
80101
}})['MEALabelRegistrationTaskStatus'],
81102
timeout_seconds=timeout_seconds)
82103

104+
def _upsert_labels_by_project_id(self, project_id: str,
                                 timeout_seconds: int):
    """ Registers all of a project's labels with this Model Run.

    Fires a server-side async registration task via a GraphQL mutation,
    then polls the task's status until completion or timeout.

    Args:
        project_id (str): id of the project whose labels are uploaded
        timeout_seconds (int): max waiting time for the async task
    Returns:
        whatever `self._wait_until_done` returns — presumably the final
        task status payload; confirm against `_wait_until_done`'s contract.
    """
    mutation_name = 'createMEAModelRunProjectLabelRegistrationTask'
    create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) {
          %s(where : { modelRunId : $modelRunId, projectId: $projectId})}
          """ % (mutation_name)

    # The mutation returns the id of the newly created async task,
    # keyed under the mutation name.
    res = self.client.execute(create_task_query_str, {
        'modelRunId': self.uid,
        'projectId': project_id
    })
    task_id = res[mutation_name]

    # Task status is polled through the generic label-registration status
    # query, keyed by the task id obtained above.
    status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
        MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
    }
    """
    return self._wait_until_done(lambda: self.client.execute(
        status_query_str, {'where': {
            'id': task_id
        }})['MEALabelRegistrationTaskStatus'],
                                 timeout_seconds=timeout_seconds)
126+
83127
def upsert_data_rows(self,
84128
data_row_ids=None,
85129
global_keys=None,

tests/integration/annotation_import/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,22 @@ def model_run_with_model_run_data_rows(client, configured_project,
440440
# TODO: Delete resources when that is possible ..
441441

442442

443+
@pytest.fixture
def model_run_with_all_project_labels(client, configured_project,
                                      model_run_predictions, model_run):
    # Fixture: a model run populated with every label of the configured
    # project (registered via `upsert_labels(project_id=...)` rather than
    # by explicit label ids).
    configured_project.enable_model_assisted_labeling()

    # Upload predictions as project labels first, so the project has
    # labels to register with the model run.
    upload_task = LabelImport.create_from_objects(
        client, configured_project.uid, f"label-import-{uuid.uuid4()}",
        model_run_predictions)
    upload_task.wait_until_done()
    model_run.upsert_labels(project_id=configured_project.uid)
    # NOTE(review): fixed sleep — presumably waiting for backend state to
    # settle after registration; confirm whether upsert_labels already
    # blocks until the task completes.
    time.sleep(3)
    yield model_run
    # Teardown: remove the model run after the test finishes.
    model_run.delete()
    # TODO: Delete resources when that is possible ..
458+
443459
class AnnotationImportTestHelpers:
444460

445461
@classmethod

tests/integration/annotation_import/test_mea_prediction_import.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,44 @@ def test_create_from_objects(model_run_with_model_run_data_rows,
3838
annotation_import.wait_until_done()
3939

4040

41+
def test_create_from_objects_all_project_labels(
        model_run_with_all_project_labels, object_predictions,
        annotation_import_test_helpers):
    # Predictions can be added to a model run whose labels were registered
    # by project id (mirrors test_create_from_objects, which uses a model
    # run populated from explicit data rows).
    name = str(uuid.uuid4())

    annotation_import = model_run_with_all_project_labels.add_predictions(
        name=name, predictions=object_predictions)

    assert annotation_import.model_run_id == model_run_with_all_project_labels.uid
    annotation_import_test_helpers.check_running_state(annotation_import, name)
    # The uploaded file should contain exactly the predictions we sent.
    annotation_import_test_helpers.assert_file_content(
        annotation_import.input_file_url, object_predictions)
    annotation_import.wait_until_done()
55+
56+
def test_model_run_project_labels(model_run_with_all_project_labels,
57+
model_run_predictions):
58+
model_run = model_run_with_all_project_labels
59+
model_run_exported_labels = model_run.export_labels(download=True)
60+
labels_indexed_by_schema_id = {}
61+
for label in model_run_exported_labels:
62+
# assuming exported array of label 'objects' has only one label per data row... as usually is when there are no label revisions
63+
schema_id = label['Label']['objects'][0]['schemaId']
64+
labels_indexed_by_schema_id[schema_id] = label
65+
66+
assert (len(
67+
labels_indexed_by_schema_id.keys())) == len(model_run_predictions)
68+
69+
# making sure the labels are in this model run are all labels uploaded to the project
70+
# by comparing some 'immutable' attributes
71+
for expected_label in model_run_predictions:
72+
schema_id = expected_label['schemaId']
73+
actual_label = labels_indexed_by_schema_id[schema_id]
74+
assert actual_label['Label']['objects'][0]['title'] == expected_label[
75+
'name']
76+
assert actual_label['DataRow ID'] == expected_label['dataRow']['id']
77+
78+
4179
def test_create_from_label_objects(model_run_with_model_run_data_rows,
4280
object_predictions,
4381
annotation_import_test_helpers):

0 commit comments

Comments
 (0)