Skip to content

Commit 056b3fe

Browse files
author
Matt Sokoloff
committed
add partition and model run status
1 parent 2fe156f commit 056b3fe

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

labelbox/client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,8 @@ def upload_data(self,
379379

380380
if not file_data or not file_data.get("uploadFile", None):
381381
raise labelbox.exceptions.LabelboxError(
382-
"Failed to upload, message: %s" % file_data.get("error", None))
382+
"Failed to upload, message: %s" % file_data or
383+
file_data.get("error"))
383384

384385
return file_data["uploadFile"]["url"]
385386

@@ -918,4 +919,4 @@ def get_model_run(self, model_run_id: str) -> ModelRun:
918919
Returns:
919920
A ModelRun object.
920921
"""
921-
return self._get_single(Entity.ModelRun, model_run_id)
922+
return self._get_single(Entity.ModelRun, model_run_id)

labelbox/data/serialization/labelbox_v1/label.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class LBV1Label(BaseModel):
141141
has_open_issues: Optional[float] = Extra('Has Open Issues')
142142
skipped: Optional[bool] = Extra('Skipped')
143143
media_type: Optional[str] = Extra('media_type')
144+
data_split: Optional[str] = Extra('Data Split')
144145

145146
def to_common(self) -> Label:
146147
if isinstance(self.label, list):

labelbox/schema/model_run.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,38 @@ def delete_model_run_data_rows(self, data_row_ids):
180180
data_row_ids_param: data_row_ids
181181
})
182182

183+
@experimental
184+
def update_status(self,
185+
status: str,
186+
metadata: Optional[Dict[str, str]] = None,
187+
error_message: Optional[Dict[str, str]] = None):
188+
189+
valid_statuses = [
190+
"EXPORTING_DATA", "PREPARING_DATA", "TRAINING_MODEL", "COMPLETE",
191+
"FAILED"
192+
]
193+
if status not in valid_statuses:
194+
raise ValueError(
195+
f"Status must be one of : `{valid_statuses}`. Found : `{status}`"
196+
)
197+
198+
data = {'status': status}
199+
if error_message:
200+
data['errorMessage'] = error_message
201+
202+
if metadata:
203+
data['metadata'] = metadata
204+
205+
self.client.execute(
206+
"""mutation setPipelineStatusPyApi($modelRunId: ID!, $data: UpdateTrainingPipelineInput!){
207+
updateTrainingPipeline(modelRun: {id : $modelRunId}, data: $data){status}
208+
}
209+
""", {
210+
'modelRunId': self.uid,
211+
'data': data
212+
},
213+
experimental=True)
214+
183215
@experimental
184216
def export_labels(
185217
self,
@@ -196,7 +228,7 @@ def export_labels(
196228
Returns:
197229
URL of the data file with this ModelRun's labels.
198230
If download=True, this instead returns the contents as NDJSON format.
199-
If the server didn't generate during the `timeout_seconds` period,
231+
If the server didn't generate during the `timeout_seconds` period,
200232
None is returned.
201233
"""
202234
sleep_time = 2

tests/integration/annotation_import/test_model_run.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from labelbox.orm.db_object import experimental
2+
from labelbox.schema.model_run import ModelRun
13
import time
24

35

@@ -82,3 +84,36 @@ def test_model_run_upsert_data_rows_with_existing_labels(
8284
def test_model_run_export_labels(model_run_with_model_run_data_rows):
8385
labels = model_run_with_model_run_data_rows.export_labels(download=True)
8486
assert len(labels) == 3
87+
88+
89+
def test_model_run_status(model_run_with_model_run_data_rows: ModelRun):
90+
91+
def get_model_run_status():
92+
return model_run_with_model_run_data_rows.client.execute(
93+
"""query trainingPipelinePyApi($modelRunId: ID!) {
94+
trainingPipeline(where: {id : $modelRunId}) {status, errorMessage, metadata}}
95+
""", {'modelRunId': model_run_with_model_run_data_rows.uid},
96+
experimental=True)['trainingPipeline']
97+
98+
model_run_status = get_model_run_status()
99+
assert model_run_status['status'] is None
100+
assert model_run_status['metadata'] is None
101+
assert model_run_status['errorMessage'] is None
102+
103+
status = "COMPLETE"
104+
metadata = {'key1': 'value1'}
105+
errorMessage = "an error"
106+
model_run_with_model_run_data_rows.update_status(status, metadata,
107+
errorMessage)
108+
109+
model_run_status = get_model_run_status()
110+
assert model_run_status['status'] == status
111+
assert model_run_status['metadata'] == metadata
112+
assert model_run_status['errorMessage'] == errorMessage
113+
114+
extra_metadata = {'key2': 'value2'}
115+
model_run_with_model_run_data_rows.update_status(status, extra_metadata)
116+
model_run_status = get_model_run_status()
117+
assert model_run_status['status'] == status
118+
assert model_run_status['metadata'] == {**metadata, **extra_metadata}
119+
assert model_run_status['errorMessage'] == errorMessage

0 commit comments

Comments
 (0)