@@ -47,22 +47,43 @@ class Status(Enum):
4747 COMPLETE = "COMPLETE"
4848 FAILED = "FAILED"
4949
50- def upsert_labels (self , label_ids , timeout_seconds = 3600 ):
50+ def upsert_labels (self ,
51+ label_ids : Optional [List [str ]] = None ,
52+ project_id : Optional [str ] = None ,
53+ timeout_seconds = 3600 ):
5154 """ Adds data rows and labels to a Model Run
5255 Args:
5356 label_ids (list): label ids to insert
57+ project_id (string): project uuid, all project labels will be uploaded
58+ Either label_ids OR project_id is required but NOT both
5459 timeout_seconds (float): Max waiting time, in seconds.
5560 Returns:
5661 ID of newly generated async task
62+
5763 """
5864
59- if len (label_ids ) < 1 :
60- raise ValueError ("Must provide at least one label id" )
65+ use_label_ids = label_ids is not None and len (label_ids ) > 0
66+ use_project_id = project_id is not None
67+
68+ if not use_label_ids and not use_project_id :
69+ raise ValueError (
70+ "Must provide at least one label id or a project id" )
71+
72+ if use_label_ids and use_project_id :
73+ raise ValueError ("Must only one of label ids, project id" )
74+
75+ if use_label_ids :
76+ return self ._upsert_labels_by_label_ids (label_ids , timeout_seconds )
77+ else : # use_project_id
78+ return self ._upsert_labels_by_project_id (project_id ,
79+ timeout_seconds )
6180
81+ def _upsert_labels_by_label_ids (self , label_ids : List [str ],
82+ timeout_seconds : int ):
6283 mutation_name = 'createMEAModelRunLabelRegistrationTask'
6384 create_task_query_str = """mutation createMEAModelRunLabelRegistrationTaskPyApi($modelRunId: ID!, $labelIds : [ID!]!) {
64- %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
65- """ % (mutation_name )
85+ %s(where : { id : $modelRunId}, data : {labelIds: $labelIds})}
86+ """ % (mutation_name )
6687
6788 res = self .client .execute (create_task_query_str , {
6889 'modelRunId' : self .uid ,
@@ -80,6 +101,29 @@ def upsert_labels(self, label_ids, timeout_seconds=3600):
80101 }})['MEALabelRegistrationTaskStatus' ],
81102 timeout_seconds = timeout_seconds )
82103
104+ def _upsert_labels_by_project_id (self , project_id : str ,
105+ timeout_seconds : int ):
106+ mutation_name = 'createMEAModelRunProjectLabelRegistrationTask'
107+ create_task_query_str = """mutation createMEAModelRunProjectLabelRegistrationTaskPyApi($modelRunId: ID!, $projectId : ID!) {
108+ %s(where : { modelRunId : $modelRunId, projectId: $projectId})}
109+ """ % (mutation_name )
110+
111+ res = self .client .execute (create_task_query_str , {
112+ 'modelRunId' : self .uid ,
113+ 'projectId' : project_id
114+ })
115+ task_id = res [mutation_name ]
116+
117+ status_query_str = """query MEALabelRegistrationTaskStatusPyApi($where: WhereUniqueIdInput!){
118+ MEALabelRegistrationTaskStatus(where: $where) {status errorMessage}
119+ }
120+ """
121+ return self ._wait_until_done (lambda : self .client .execute (
122+ status_query_str , {'where' : {
123+ 'id' : task_id
124+ }})['MEALabelRegistrationTaskStatus' ],
125+ timeout_seconds = timeout_seconds )
126+
83127 def upsert_data_rows (self ,
84128 data_row_ids = None ,
85129 global_keys = None ,
0 commit comments