Skip to content

Commit 1f41253

Browse files
author
Matt Sokoloff
committed
use enums
1 parent edb9e18 commit 1f41253

File tree

3 files changed

+64
-26
lines changed

3 files changed

+64
-26
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from labelbox.schema.role import Role, ProjectRole
2222
from labelbox.schema.invite import Invite, InviteLimit
2323
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
24-
from labelbox.schema.model_run import ModelRun
24+
from labelbox.schema.model_run import ModelRun, DataSplit
2525
from labelbox.schema.benchmark import Benchmark
2626
from labelbox.schema.iam_integration import IAMIntegration
2727
from labelbox.schema.resource_tag import ResourceTag

labelbox/schema/model_run.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import requests
77
import ndjson
8+
from enum import Enum
89

910
from labelbox.pagination import PaginatedCollection
1011
from labelbox.orm.query import results_query_part
@@ -17,13 +18,27 @@
1718
logger = logging.getLogger(__name__)
1819

1920

21+
class DataSplit(Enum):
22+
TRAINING = "TRAINING"
23+
TEST = "TEST"
24+
VALIDATION = "VALIDATION"
25+
UNASSIGNED = "UNASSIGNED"
26+
27+
2028
class ModelRun(DbObject):
2129
name = Field.String("name")
2230
updated_at = Field.DateTime("updated_at")
2331
created_at = Field.DateTime("created_at")
2432
created_by_id = Field.String("created_by_id", "createdBy")
2533
model_id = Field.String("model_id")
2634

35+
class Status(Enum):
36+
EXPORTING_DATA = "EXPORTING_DATA"
37+
PREPARING_DATA = "PREPARING_DATA"
38+
TRAINING_MODEL = "TRAINING_MODEL"
39+
COMPLETE = "COMPLETE"
40+
FAILED = "FAILED"
41+
2742
def upsert_labels(self, label_ids, timeout_seconds=60):
2843
""" Adds data rows and labels to a model run
2944
Args:
@@ -90,7 +105,7 @@ def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
90105
}})['MEADataRowRegistrationTaskStatus'],
91106
timeout_seconds=timeout_seconds)
92107

93-
def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
108+
def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5):
94109
# Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
95110
original_timeout = timeout_seconds
96111
while True:
@@ -105,7 +120,6 @@ def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
105120
raise TimeoutError(
106121
f"Unable to complete import within {original_timeout} seconds."
107122
)
108-
109123
time.sleep(sleep_time)
110124

111125
def add_predictions(
@@ -162,7 +176,7 @@ def delete(self):
162176
deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param, ids_param)
163177
self.client.execute(query_str, {ids_param: str(self.uid)})
164178

165-
def delete_model_run_data_rows(self, data_row_ids):
179+
def delete_model_run_data_rows(self, data_row_ids: List[str]):
166180
""" Deletes data rows from model runs.
167181
168182
Args:
@@ -183,11 +197,20 @@ def delete_model_run_data_rows(self, data_row_ids):
183197

184198
@experimental
185199
def assign_data_rows_to_split(self,
186-
data_row_ids,
187-
split,
200+
data_row_ids: List[str],
201+
split: Union[DataSplit, str],
188202
timeout_seconds=60):
189-
valid_splits = ["TRAINING", "TEST", "VALIDATION"]
190-
if split not in valid_splits:
203+
204+
split_value = split.value if isinstance(split, DataSplit) else split
205+
206+
if split_value == DataSplit.UNASSIGNED.value:
207+
raise ValueError(
208+
f"Cannot assign split value of `{DataSplit.UNASSIGNED.value}`.")
209+
210+
valid_splits = filter(lambda name: name != DataSplit.UNASSIGNED.value,
211+
DataSplit._member_names_)
212+
213+
if split_value not in valid_splits:
191214
raise ValueError(
192215
f"split must be one of : `{valid_splits}`. Found : `{split}`")
193216

@@ -198,7 +221,7 @@ def assign_data_rows_to_split(self,
198221
'modelRunId': self.uid,
199222
'data': {
200223
'assignments': [{
201-
'split': split,
224+
'split': split_value,
202225
'dataRowIds': data_row_ids
203226
}]
204227
}
@@ -216,20 +239,18 @@ def assign_data_rows_to_split(self,
216239

217240
@experimental
218241
def update_status(self,
219-
status: str,
242+
status: Union[str, "ModelRun.Status"],
220243
metadata: Optional[Dict[str, str]] = None,
221244
error_message: Optional[str] = None):
222245

223-
valid_statuses = [
224-
"EXPORTING_DATA", "PREPARING_DATA", "TRAINING_MODEL", "COMPLETE",
225-
"FAILED"
226-
]
227-
if status not in valid_statuses:
246+
status_value = status.value if isinstance(status,
247+
ModelRun.Status) else status
248+
if status_value not in ModelRun.Status._member_names_:
228249
raise ValueError(
229-
f"Status must be one of : `{valid_statuses}`. Found : `{status}`"
250+
f"Status must be one of : `{ModelRun.Status._member_names_}`. Found : `{status_value}`"
230251
)
231252

232-
data: Dict[str, Any] = {'status': status}
253+
data: Dict[str, Any] = {'status': status_value}
233254
if error_message:
234255
data['errorMessage'] = error_message
235256

@@ -298,7 +319,7 @@ def export_labels(
298319
class ModelRunDataRow(DbObject):
299320
label_id = Field.String("label_id")
300321
model_run_id = Field.String("model_run_id")
301-
data_split = Field.String("data_split")
322+
data_split = Field.Enum(DataSplit, "data_split")
302323
data_row = Relationship.ToOne("DataRow", False, cache=True)
303324

304325
def __init__(self, client, model_id, *args, **kwargs):

tests/integration/annotation_import/test_model_run.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pytest
44

55
from collections import Counter
6+
from labelbox import DataSplit, ModelRun
67

78

89
def test_model_run(client, configured_project_with_label, rand_gen):
@@ -122,6 +123,16 @@ def get_model_run_status():
122123
assert model_run_status['metadata'] == {**metadata, **extra_metadata}
123124
assert model_run_status['errorMessage'] == errorMessage
124125

126+
status = ModelRun.Status.FAILED
127+
model_run_with_model_run_data_rows.update_status(status, metadata,
128+
errorMessage)
129+
model_run_status = get_model_run_status()
130+
assert model_run_status['status'] == status.value
131+
132+
with pytest.raises(ValueError):
133+
model_run_with_model_run_data_rows.update_status(
134+
"INVALID", metadata, errorMessage)
135+
125136

126137
def test_model_run_split_assignment(model_run, dataset, image_url):
127138
n_data_rows = 10
@@ -132,13 +143,19 @@ def test_model_run_split_assignment(model_run, dataset, image_url):
132143

133144
model_run.upsert_data_rows(data_row_ids)
134145

135-
for split in ["TRAINING", "TEST", "VALIDATION"]:
136-
model_run.assign_data_rows_to_split(data_row_ids[:(n_data_rows // 2)],
137-
split)
138-
counts = Counter()
139-
for data_row in model_run.model_run_data_rows():
140-
counts[data_row.data_split] += 1
141-
assert counts[split] == n_data_rows // 2
142-
143146
with pytest.raises(ValueError):
144147
model_run.assign_data_rows_to_split(data_row_ids, "INVALID SPLIT")
148+
149+
with pytest.raises(ValueError):
150+
model_run.assign_data_rows_to_split(data_row_ids, DataSplit.UNASSIGNED)
151+
152+
for split in ["TRAINING", "TEST", "VALIDATION", *DataSplit]:
153+
if split == DataSplit.UNASSIGNED:
154+
continue
155+
156+
model_run.assign_data_rows_to_split(data_row_ids, split)
157+
counts = Counter()
158+
for data_row in model_run.model_run_data_rows():
159+
counts[data_row.data_split.value] += 1
160+
split = split.value if isinstance(split, DataSplit) else split
161+
assert counts[split] == n_data_rows

0 commit comments

Comments
 (0)