|
26 | 26 | from labelbox.schema.resource_tag import ResourceTag |
27 | 27 | from labelbox.schema.task import Task |
28 | 28 | from labelbox.schema.user import User |
| 29 | +from labelbox.schema.task_queue import TaskQueue |
29 | 30 |
|
30 | 31 | if TYPE_CHECKING: |
31 | 32 | from labelbox import BulkImportRequest |
@@ -69,6 +70,7 @@ class Project(DbObject, Updateable, Deletable): |
69 | 70 | webhooks (Relationship): `ToMany` relationship to Webhook |
70 | 71 | benchmarks (Relationship): `ToMany` relationship to Benchmark |
71 | 72 | ontology (Relationship): `ToOne` relationship to Ontology |
| 73 | + task_queues (Relationship): `ToMany` relationship to TaskQueue |
72 | 74 | """ |
73 | 75 |
|
74 | 76 | name = Field.String("name") |
@@ -806,54 +808,34 @@ def _create_batch_async(self, |
806 | 808 |
|
807 | 809 | task_id = res['taskId'] |
808 | 810 |
|
809 | | - timeout_seconds = 600 |
810 | | - sleep_time = 2 |
811 | | - get_task_query_str = """query %s($taskId: ID!) { |
812 | | - task(where: {id: $taskId}) { |
813 | | - status |
| 811 | + task = self._wait_for_task(task_id) |
| 812 | + if task.status != "COMPLETE": |
| 813 | + raise LabelboxError(f"Batch was not created successfully: " + |
| 814 | + json.dumps(task.errors)) |
| 815 | + |
| 816 | + # obtain batch entity to return |
| 817 | + get_batch_str = """query %s($projectId: ID!, $batchId: ID!) { |
| 818 | + project(where: {id: $projectId}) { |
| 819 | + batches(where: {id: $batchId}) { |
| 820 | + nodes { |
| 821 | + %s |
| 822 | + } |
| 823 | + } |
814 | 824 | } |
815 | 825 | } |
816 | | - """ % "getTaskPyApi" |
| 826 | + """ % ("getProjectBatchPyApi", |
| 827 | + query.results_query_part(Entity.Batch)) |
817 | 828 |
|
818 | | - while True: |
819 | | - task_status = self.client.execute( |
820 | | - get_task_query_str, {'taskId': task_id}, |
821 | | - experimental=True)['task']['status'] |
822 | | - |
823 | | - if task_status == "COMPLETE": |
824 | | - # obtain batch entity to return |
825 | | - get_batch_str = """query %s($projectId: ID!, $batchId: ID!) { |
826 | | - project(where: {id: $projectId}) { |
827 | | - batches(where: {id: $batchId}) { |
828 | | - nodes { |
829 | | - %s |
830 | | - } |
831 | | - } |
832 | | - } |
833 | | - } |
834 | | - """ % ("getProjectBatchPyApi", |
835 | | - query.results_query_part(Entity.Batch)) |
836 | | - |
837 | | - batch = self.client.execute( |
838 | | - get_batch_str, { |
839 | | - "projectId": self.uid, |
840 | | - "batchId": batch_id |
841 | | - }, |
842 | | - timeout=180.0, |
843 | | - experimental=True)["project"]["batches"]["nodes"][0] |
844 | | - |
845 | | - # TODO async endpoints currently do not provide failed_data_row_ids in response |
846 | | - return Entity.Batch(self.client, self.uid, batch) |
847 | | - elif task_status == "IN_PROGRESS": |
848 | | - timeout_seconds -= sleep_time |
849 | | - if timeout_seconds <= 0: |
850 | | - raise LabelboxError( |
851 | | - f"Timed out while waiting for batch to be created.") |
852 | | - logger.debug("Creating batch, waiting for server...", self.uid) |
853 | | - time.sleep(sleep_time) |
854 | | - continue |
855 | | - else: |
856 | | - raise LabelboxError(f"Batch was not created successfully.") |
| 829 | + batch = self.client.execute( |
| 830 | + get_batch_str, { |
| 831 | + "projectId": self.uid, |
| 832 | + "batchId": batch_id |
| 833 | + }, |
| 834 | + timeout=180.0, |
| 835 | + experimental=True)["project"]["batches"]["nodes"][0] |
| 836 | + |
| 837 | + # TODO async endpoints currently do not provide failed_data_row_ids in response |
| 838 | + return Entity.Batch(self.client, self.uid, batch) |
857 | 839 |
|
858 | 840 | def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": |
859 | 841 | """ |
@@ -1139,6 +1121,81 @@ def batches(self) -> PaginatedCollection: |
1139 | 1121 | cursor_path=['project', 'batches', 'pageInfo', 'endCursor'], |
1140 | 1122 | experimental=True) |
1141 | 1123 |
|
| 1124 | + def task_queues(self) -> List[TaskQueue]: |
| 1125 | + """ Fetch all task queues that belong to this project |
| 1126 | +
|
| 1127 | + Returns: |
| 1128 | + A `List of `TaskQueue`s |
| 1129 | + """ |
| 1130 | + query_str = """query GetProjectTaskQueuesPyApi($projectId: ID!) { |
| 1131 | + project(where: {id: $projectId}) { |
| 1132 | + taskQueues { |
| 1133 | + %s |
| 1134 | + } |
| 1135 | + } |
| 1136 | + } |
| 1137 | + """ % (query.results_query_part(Entity.TaskQueue)) |
| 1138 | + |
| 1139 | + task_queue_values = self.client.execute( |
| 1140 | + query_str, {"projectId": self.uid}, |
| 1141 | + timeout=180.0, |
| 1142 | + experimental=True)["project"]["taskQueues"] |
| 1143 | + |
| 1144 | + return [ |
| 1145 | + Entity.TaskQueue(self.client, field_values) |
| 1146 | + for field_values in task_queue_values |
| 1147 | + ] |
| 1148 | + |
| 1149 | + def move_data_rows_to_task_queue(self, data_row_ids: List[str], |
| 1150 | + task_queue_id: str): |
| 1151 | + """ |
| 1152 | +
|
| 1153 | + Moves data rows to the specified task queue. |
| 1154 | +
|
| 1155 | + Args: |
| 1156 | + data_row_ids: a list of data row ids to be moved |
| 1157 | + task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue |
| 1158 | +
|
| 1159 | + Returns: |
| 1160 | + None if successful, or a raised error on failure |
| 1161 | +
|
| 1162 | + """ |
| 1163 | + method = "createBulkAddRowsToQueueTask" |
| 1164 | + query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi( |
| 1165 | + $projectId: ID! |
| 1166 | + $queueId: ID |
| 1167 | + $dataRowIds: [ID!]! |
| 1168 | + ) { |
| 1169 | + project(where: { id: $projectId }) { |
| 1170 | + %s( |
| 1171 | + data: { queueId: $queueId, dataRowIds: $dataRowIds } |
| 1172 | + ) { |
| 1173 | + taskId |
| 1174 | + } |
| 1175 | + } |
| 1176 | + } |
| 1177 | + """ % method |
| 1178 | + |
| 1179 | + task_id = self.client.execute( |
| 1180 | + query_str, { |
| 1181 | + "projectId": self.uid, |
| 1182 | + "queueId": task_queue_id, |
| 1183 | + "dataRowIds": data_row_ids |
| 1184 | + }, |
| 1185 | + timeout=180.0, |
| 1186 | + experimental=True)["project"][method]["taskId"] |
| 1187 | + |
| 1188 | + task = self._wait_for_task(task_id) |
| 1189 | + if task.status != "COMPLETE": |
| 1190 | + raise LabelboxError(f"Data rows were not moved successfully: " + |
| 1191 | + json.dumps(task.errors)) |
| 1192 | + |
| 1193 | + def _wait_for_task(self, task_id: str) -> Task: |
| 1194 | + task = Task.get_task(self.client, task_id) |
| 1195 | + task.wait_till_done() |
| 1196 | + |
| 1197 | + return task |
| 1198 | + |
1142 | 1199 | def upload_annotations( |
1143 | 1200 | self, |
1144 | 1201 | name: str, |
|
0 commit comments