55import logging
66import requests
77import ndjson
8+ from enum import Enum
89
910from labelbox .pagination import PaginatedCollection
1011from labelbox .orm .query import results_query_part
1718logger = logging .getLogger (__name__ )
1819
1920
21+ class DataSplit (Enum ):
22+ TRAINING = "TRAINING"
23+ TEST = "TEST"
24+ VALIDATION = "VALIDATION"
25+ UNASSIGNED = "UNASSIGNED"
26+
27+
2028class ModelRun (DbObject ):
2129 name = Field .String ("name" )
2230 updated_at = Field .DateTime ("updated_at" )
2331 created_at = Field .DateTime ("created_at" )
2432 created_by_id = Field .String ("created_by_id" , "createdBy" )
2533 model_id = Field .String ("model_id" )
2634
35+ class Status (Enum ):
36+ EXPORTING_DATA = "EXPORTING_DATA"
37+ PREPARING_DATA = "PREPARING_DATA"
38+ TRAINING_MODEL = "TRAINING_MODEL"
39+ COMPLETE = "COMPLETE"
40+ FAILED = "FAILED"
41+
2742 def upsert_labels (self , label_ids , timeout_seconds = 60 ):
2843 """ Adds data rows and labels to a model run
2944 Args:
@@ -90,7 +105,7 @@ def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
90105 }})['MEADataRowRegistrationTaskStatus' ],
91106 timeout_seconds = timeout_seconds )
92107
93- def _wait_until_done (self , status_fn , timeout_seconds = 60 , sleep_time = 5 ):
108+ def _wait_until_done (self , status_fn , timeout_seconds = 120 , sleep_time = 5 ):
94109 # Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
95110 original_timeout = timeout_seconds
96111 while True :
@@ -105,7 +120,6 @@ def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
105120 raise TimeoutError (
106121 f"Unable to complete import within { original_timeout } seconds."
107122 )
108-
109123 time .sleep (sleep_time )
110124
111125 def add_predictions (
@@ -162,7 +176,7 @@ def delete(self):
162176 deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param , ids_param )
163177 self .client .execute (query_str , {ids_param : str (self .uid )})
164178
165- def delete_model_run_data_rows (self , data_row_ids ):
179+ def delete_model_run_data_rows (self , data_row_ids : List [ str ] ):
166180 """ Deletes data rows from model runs.
167181
168182 Args:
@@ -183,11 +197,20 @@ def delete_model_run_data_rows(self, data_row_ids):
183197
184198 @experimental
185199 def assign_data_rows_to_split (self ,
186- data_row_ids ,
187- split ,
200+ data_row_ids : List [ str ] ,
201+ split : Union [ DataSplit , str ] ,
188202 timeout_seconds = 60 ):
189- valid_splits = ["TRAINING" , "TEST" , "VALIDATION" ]
190- if split not in valid_splits :
203+
204+ split_value = split .value if isinstance (split , DataSplit ) else split
205+
206+ if split_value == DataSplit .UNASSIGNED .value :
207+ raise ValueError (
208+ f"Cannot assign split value of `{ DataSplit .UNASSIGNED .value } `." )
209+
210+ valid_splits = filter (lambda name : name != DataSplit .UNASSIGNED .value ,
211+ DataSplit ._member_names_ )
212+
213+ if split_value not in valid_splits :
191214 raise ValueError (
192215 f"split must be one of : `{ valid_splits } `. Found : `{ split } `" )
193216
@@ -198,7 +221,7 @@ def assign_data_rows_to_split(self,
198221 'modelRunId' : self .uid ,
199222 'data' : {
200223 'assignments' : [{
201- 'split' : split ,
224+ 'split' : split_value ,
202225 'dataRowIds' : data_row_ids
203226 }]
204227 }
@@ -216,20 +239,18 @@ def assign_data_rows_to_split(self,
216239
217240 @experimental
218241 def update_status (self ,
219- status : str ,
242+ status : Union [ str , "ModelRun.Status" ] ,
220243 metadata : Optional [Dict [str , str ]] = None ,
221244 error_message : Optional [str ] = None ):
222245
223- valid_statuses = [
224- "EXPORTING_DATA" , "PREPARING_DATA" , "TRAINING_MODEL" , "COMPLETE" ,
225- "FAILED"
226- ]
227- if status not in valid_statuses :
246+ status_value = status .value if isinstance (status ,
247+ ModelRun .Status ) else status
248+ if status_value not in ModelRun .Status ._member_names_ :
228249 raise ValueError (
229- f"Status must be one of : `{ valid_statuses } `. Found : `{ status } `"
250+ f"Status must be one of : `{ ModelRun . Status . _member_names_ } `. Found : `{ status_value } `"
230251 )
231252
232- data : Dict [str , Any ] = {'status' : status }
253+ data : Dict [str , Any ] = {'status' : status_value }
233254 if error_message :
234255 data ['errorMessage' ] = error_message
235256
@@ -298,7 +319,7 @@ def export_labels(
298319class ModelRunDataRow (DbObject ):
299320 label_id = Field .String ("label_id" )
300321 model_run_id = Field .String ("model_run_id" )
301- data_split = Field .String ( "data_split" )
322+ data_split = Field .Enum ( DataSplit , "data_split" )
302323 data_row = Relationship .ToOne ("DataRow" , False , cache = True )
303324
304325 def __init__ (self , client , model_id , * args , ** kwargs ):
0 commit comments