1+ # type: ignore
12from typing import TYPE_CHECKING , Dict , Iterable , Union , List , Optional , Any
23from pathlib import Path
34import os
45import time
56import logging
67import requests
78import ndjson
9+ from enum import Enum
810
911from labelbox .pagination import PaginatedCollection
1012from labelbox .orm .query import results_query_part
1719logger = logging .getLogger (__name__ )
1820
1921
22+ class DataSplit (Enum ):
23+ TRAINING = "TRAINING"
24+ TEST = "TEST"
25+ VALIDATION = "VALIDATION"
26+ UNASSIGNED = "UNASSIGNED"
27+
28+
2029class ModelRun (DbObject ):
2130 name = Field .String ("name" )
2231 updated_at = Field .DateTime ("updated_at" )
2332 created_at = Field .DateTime ("created_at" )
2433 created_by_id = Field .String ("created_by_id" , "createdBy" )
2534 model_id = Field .String ("model_id" )
2635
36+ class Status (Enum ):
37+ EXPORTING_DATA = "EXPORTING_DATA"
38+ PREPARING_DATA = "PREPARING_DATA"
39+ TRAINING_MODEL = "TRAINING_MODEL"
40+ COMPLETE = "COMPLETE"
41+ FAILED = "FAILED"
42+
2743 def upsert_labels (self , label_ids , timeout_seconds = 60 ):
2844 """ Adds data rows and labels to a model run
2945 Args:
@@ -90,8 +106,9 @@ def upsert_data_rows(self, data_row_ids, timeout_seconds=60):
90106 }})['MEADataRowRegistrationTaskStatus' ],
91107 timeout_seconds = timeout_seconds )
92108
93- def _wait_until_done (self , status_fn , timeout_seconds = 60 , sleep_time = 5 ):
109+ def _wait_until_done (self , status_fn , timeout_seconds = 120 , sleep_time = 5 ):
94110 # Do not use this function outside of the scope of upsert_data_rows or upsert_labels. It could change.
111+ original_timeout = timeout_seconds
95112 while True :
96113 res = status_fn ()
97114 if res ['status' ] == 'COMPLETE' :
@@ -102,9 +119,8 @@ def _wait_until_done(self, status_fn, timeout_seconds=60, sleep_time=5):
102119 timeout_seconds -= sleep_time
103120 if timeout_seconds <= 0 :
104121 raise TimeoutError (
105- f"Unable to complete import within { timeout_seconds } seconds."
122+ f"Unable to complete import within { original_timeout } seconds."
106123 )
107-
108124 time .sleep (sleep_time )
109125
110126 def add_predictions (
@@ -161,7 +177,7 @@ def delete(self):
161177 deleteModelRuns(where: {ids: [$%s]})}""" % (ids_param , ids_param )
162178 self .client .execute (query_str , {ids_param : str (self .uid )})
163179
164- def delete_model_run_data_rows (self , data_row_ids ):
180+ def delete_model_run_data_rows (self , data_row_ids : List [ str ] ):
165181 """ Deletes data rows from model runs.
166182
167183 Args:
@@ -180,22 +196,62 @@ def delete_model_run_data_rows(self, data_row_ids):
180196 data_row_ids_param : data_row_ids
181197 })
182198
199+ @experimental
200+ def assign_data_rows_to_split (self ,
201+ data_row_ids : List [str ],
202+ split : Union [DataSplit , str ],
203+ timeout_seconds = 120 ):
204+
205+ split_value = split .value if isinstance (split , DataSplit ) else split
206+
207+ if split_value == DataSplit .UNASSIGNED .value :
208+ raise ValueError (
209+ f"Cannot assign split value of `{ DataSplit .UNASSIGNED .value } `." )
210+
211+ valid_splits = filter (lambda name : name != DataSplit .UNASSIGNED .value ,
212+ DataSplit ._member_names_ )
213+
214+ if split_value not in valid_splits :
215+ raise ValueError (
216+ f"`split` must be one of : `{ valid_splits } `. Found : `{ split } `" )
217+
218+ task_id = self .client .execute (
219+ """mutation assignDataSplitPyApi($modelRunId: ID!, $data: CreateAssignDataRowsToDataSplitTaskInput!){
220+ createAssignDataRowsToDataSplitTask(modelRun : {id: $modelRunId}, data: $data)}
221+ """ , {
222+ 'modelRunId' : self .uid ,
223+ 'data' : {
224+ 'assignments' : [{
225+ 'split' : split_value ,
226+ 'dataRowIds' : data_row_ids
227+ }]
228+ }
229+ },
230+ experimental = True )['createAssignDataRowsToDataSplitTask' ]
231+
232+ status_query_str = """query assignDataRowsToDataSplitTaskStatusPyApi($id: ID!){
233+ assignDataRowsToDataSplitTaskStatus(where: {id : $id}){status errorMessage}}
234+ """
235+
236+ return self ._wait_until_done (lambda : self .client .execute (
237+ status_query_str , {'id' : task_id }, experimental = True )[
238+ 'assignDataRowsToDataSplitTaskStatus' ],
239+ timeout_seconds = timeout_seconds )
240+
183241 @experimental
184242 def update_status (self ,
185- status : str ,
243+ status : Union [ str , "ModelRun.Status" ] ,
186244 metadata : Optional [Dict [str , str ]] = None ,
187245 error_message : Optional [str ] = None ):
188246
189- valid_statuses = [
190- "EXPORTING_DATA" , "PREPARING_DATA" , "TRAINING_MODEL" , "COMPLETE" ,
191- "FAILED"
192- ]
193- if status not in valid_statuses :
247+ status_value = status .value if isinstance (status ,
248+ ModelRun .Status ) else status
249+ if status_value not in ModelRun .Status ._member_names_ :
194250 raise ValueError (
195- f"Status must be one of : `{ valid_statuses } `. Found : `{ status } `"
251+ f"Status must be one of : `{ ModelRun . Status . _member_names_ } `. Found : `{ status_value } `"
196252 )
197253
198- data : Dict [str , Any ] = {'status' : status }
254+ data : Dict [str , Any ] = {'status' : status_value }
199255 if error_message :
200256 data ['errorMessage' ] = error_message
201257
@@ -264,6 +320,7 @@ def export_labels(
264320class ModelRunDataRow (DbObject ):
265321 label_id = Field .String ("label_id" )
266322 model_run_id = Field .String ("model_run_id" )
323+ data_split = Field .Enum (DataSplit , "data_split" )
267324 data_row = Relationship .ToOne ("DataRow" , False , cache = True )
268325
269326 def __init__ (self , client , model_id , * args , ** kwargs ):
0 commit comments