Skip to content

Commit 9963a0f

Browse files
author
Val Brodsky
committed
Add logic to upload labels by project id to model run
1 parent 5af2944 commit 9963a0f

File tree

1 file changed

+49
-6
lines changed

1 file changed

+49
-6
lines changed

labelbox/schema/model_run.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,41 @@ class Status(Enum):
4747
COMPLETE = "COMPLETE"
4848
FAILED = "FAILED"
4949

50-
def upsert_labels(self, label_ids, timeout_seconds=3600):
50+
def upsert_labels(self,
51+
label_ids: Optional[List[str]] = None,
52+
project_id: Optional[str] = None,
53+
timeout_seconds=3600):
5154
""" Adds data rows and labels to a Model Run
5255
Args:
5356
label_ids (list): label ids to insert
57+
project_id (string): project uuid, all project labels will be uploaded
5458
timeout_seconds (float): Max waiting time, in seconds.
5559
Returns:
5660
ID of newly generated async task
5761
"""
5862

59-
if len(label_ids) < 1:
60-
raise ValueError("Must provide at least one label id")
63+
use_label_ids = label_ids is not None and len(label_ids) > 0
64+
use_project_id = project_id is not None
6165

66+
if not use_label_ids and not use_project_id:
67+
raise ValueError("Must provide at least one label id or a project id")
68+
69+
if use_label_ids and use_project_id:
70+
raise ValueError("Must only one of label ids, project id")
71+
72+
if use_label_ids:
73+
return self._upsert_labels_by_label_ids(label_ids)
74+
else: # use_project_id
75+
return self._upsert_labels_by_project_id(project_id)
76+
77+
78+
79+
def _upsert_labels_by_label_ids(self,
80+
label_ids: List[str]):
6281
mutation_name = 'createMEAModelRunLabelRegistrationTask'
6382
create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
64-
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
65-
""" % (mutation_name)
83+
%s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
84+
""" % (mutation_name)
6685

6786
res = self.client.execute(create_task_query_str, {
6887
'modelRunId': self.uid,
@@ -78,7 +97,31 @@ def upsert_labels(self, label_ids, timeout_seconds=3600):
7897
status_query_str, {'where': {
7998
'id': task_id
8099
}})['MEALabelRegistrationTaskStatus'],
81-
timeout_seconds=timeout_seconds)
100+
timeout_seconds=timeout_seconds)
101+
102+
def _upsert_labels_by_project_id(self,
103+
project_id: str):
104+
mutation_name = 'createMEAModelRunProjectLabelRegistrationTask'
105+
create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) {
106+
%s(where : { modelRunId : $modelRunId, projectId: $projectId}}
107+
""" % (mutation_name)
108+
109+
res = self.client.execute(create_task_query_str, {
110+
'modelRunId': self.uid,
111+
'projectId': project_id
112+
})
113+
task_id = res[mutation_name]
114+
115+
status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
116+
MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
117+
}
118+
"""
119+
return self._wait_until_done(lambda: self.client.execute(
120+
status_query_str, {'where': {
121+
'id': task_id
122+
}})['MEALabelRegistrationTaskStatus'],
123+
timeout_seconds=timeout_seconds)
124+
82125

83126
def upsert_data_rows(self,
84127
data_row_ids=None,

0 commit comments

Comments
 (0)