11from typing import Dict , Iterable , Union
22from pathlib import Path
33import os
4+ import time
45
56from labelbox .pagination import PaginatedCollection
67from labelbox .schema .annotation_import import MEAPredictionImport
@@ -16,20 +17,54 @@ class ModelRun(DbObject):
1617 created_by_id = Field .String ("created_by_id" , "createdBy" )
1718 model_id = Field .String ("model_id" )
1819
19- def upsert_labels (self , label_ids ):
20+ def upsert_labels (self , label_ids , timeout_seconds = 60 ):
21+ """ Calls GraphQL API to start the MEA labels registration process
22+ Args:
23+ label_ids (list): label ids to insert
24+ timeout_seconds (float): Max waiting time, in seconds.
25+ Returns:
26+ ID of newly generated async task
27+ """
2028
2129 if len (label_ids ) < 1 :
2230 raise ValueError ("Must provide at least one label id" )
2331
24- query_str = """mutation upsertModelRunLabelsPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
25- upsertModelRunLabels(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
26- """
27- res = self .client .execute (query_str , {
32+ sleep_time = 5
33+
34+ mutation_name = 'createMEAModelRunLabelRegistrationTask'
35+ create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
36+ %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
37+ """ % (mutation_name )
38+
39+ res = self .client .execute (create_task_query_str , {
2840 'modelRunId' : self .uid ,
2941 'labelIds' : label_ids
3042 })
31- # TODO: Return a task
32- return True
43+ task_id = res [mutation_name ]
44+
45+ status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
46+ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
47+ }
48+ """
49+
50+ while True :
51+ res = self .client .execute (status_query_str ,
52+ {'where' : {
53+ 'id' : task_id
54+ }})['MEALabelRegistrationTaskStatus' ]
55+ if res ['status' ] == 'COMPLETE' :
56+ return res
57+ elif res ['status' ] == 'FAILED' :
58+ raise Exception (
59+ f"MEA Label Import Failed. Details : { res ['errorMessage' ]} " )
60+
61+ timeout_seconds -= sleep_time
62+ if timeout_seconds <= 0 :
63+ raise TimeoutError (
64+ f"Unable to complete import within { timeout_seconds } seconds."
65+ )
66+
67+ time .sleep (sleep_time )
3368
3469 def add_predictions (
3570 self ,
0 commit comments