|
26 | 26 | from labelbox.schema.resource_tag import ResourceTag |
27 | 27 | from labelbox.schema.task import Task |
28 | 28 | from labelbox.schema.user import User |
| 29 | +from labelbox.schema.task_queue import TaskQueue |
29 | 30 |
|
30 | 31 | if TYPE_CHECKING: |
31 | 32 | from labelbox import BulkImportRequest |
@@ -69,6 +70,7 @@ class Project(DbObject, Updateable, Deletable): |
69 | 70 | webhooks (Relationship): `ToMany` relationship to Webhook |
70 | 71 | benchmarks (Relationship): `ToMany` relationship to Benchmark |
71 | 72 | ontology (Relationship): `ToOne` relationship to Ontology |
| 73 | + task_queues (method): returns a list of the project's `TaskQueue`s |
72 | 74 | """ |
73 | 75 |
|
74 | 76 | name = Field.String("name") |
@@ -794,54 +796,33 @@ def _create_batch_async(self, |
794 | 796 |
|
795 | 797 | task_id = res['taskId'] |
796 | 798 |
|
797 | | - timeout_seconds = 600 |
798 | | - sleep_time = 2 |
799 | | - get_task_query_str = """query %s($taskId: ID!) { |
800 | | - task(where: {id: $taskId}) { |
801 | | - status |
| 799 | + status = self._wait_for_task(task_id) |
| 800 | + if status != "COMPLETE": |
| 801 | + raise LabelboxError("Batch was not created successfully.") |
| 802 | + |
| 803 | + # obtain batch entity to return |
| 804 | + get_batch_str = """query %s($projectId: ID!, $batchId: ID!) { |
| 805 | + project(where: {id: $projectId}) { |
| 806 | + batches(where: {id: $batchId}) { |
| 807 | + nodes { |
| 808 | + %s |
| 809 | + } |
| 810 | + } |
802 | 811 | } |
803 | 812 | } |
804 | | - """ % "getTaskPyApi" |
| 813 | + """ % ("getProjectBatchPyApi", |
| 814 | + query.results_query_part(Entity.Batch)) |
805 | 815 |
|
806 | | - while True: |
807 | | - task_status = self.client.execute( |
808 | | - get_task_query_str, {'taskId': task_id}, |
809 | | - experimental=True)['task']['status'] |
| 816 | + batch = self.client.execute( |
| 817 | + get_batch_str, { |
| 818 | + "projectId": self.uid, |
| 819 | + "batchId": batch_id |
| 820 | + }, |
| 821 | + timeout=180.0, |
| 822 | + experimental=True)["project"]["batches"]["nodes"][0] |
810 | 823 |
|
811 | | - if task_status == "COMPLETE": |
812 | | - # obtain batch entity to return |
813 | | - get_batch_str = """query %s($projectId: ID!, $batchId: ID!) { |
814 | | - project(where: {id: $projectId}) { |
815 | | - batches(where: {id: $batchId}) { |
816 | | - nodes { |
817 | | - %s |
818 | | - } |
819 | | - } |
820 | | - } |
821 | | - } |
822 | | - """ % ("getProjectBatchPyApi", |
823 | | - query.results_query_part(Entity.Batch)) |
824 | | - |
825 | | - batch = self.client.execute( |
826 | | - get_batch_str, { |
827 | | - "projectId": self.uid, |
828 | | - "batchId": batch_id |
829 | | - }, |
830 | | - timeout=180.0, |
831 | | - experimental=True)["project"]["batches"]["nodes"][0] |
832 | | - |
833 | | - # TODO async endpoints currently do not provide failed_data_row_ids in response |
834 | | - return Entity.Batch(self.client, self.uid, batch) |
835 | | - elif task_status == "IN_PROGRESS": |
836 | | - timeout_seconds -= sleep_time |
837 | | - if timeout_seconds <= 0: |
838 | | - raise LabelboxError( |
839 | | - f"Timed out while waiting for batch to be created.") |
840 | | - logger.debug("Creating batch, waiting for server...", self.uid) |
841 | | - time.sleep(sleep_time) |
842 | | - continue |
843 | | - else: |
844 | | - raise LabelboxError(f"Batch was not created successfully.") |
| 824 | + # TODO async endpoints currently do not provide failed_data_row_ids in response |
| 825 | + return Entity.Batch(self.client, self.uid, batch) |
845 | 826 |
|
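A caller-side sketch of what the refactor above amounts to (the API key, project id, batch name, and data row ids are placeholders, and the assumption that `create_batch` routes large batches through `_create_batch_async` comes from this diff's context rather than from code shown here):

    from labelbox import Client

    client = Client(api_key="<YOUR_API_KEY>")     # placeholder credentials
    project = client.get_project("<PROJECT_ID>")  # placeholder project id

    # Creating the batch kicks off a server-side task; the SDK now polls it
    # through the shared _wait_for_task helper and, once the task reports
    # COMPLETE, fetches the Batch entity with the getProjectBatchPyApi query.
    batch = project.create_batch("demo-batch",
                                 ["<DATA_ROW_ID_1>", "<DATA_ROW_ID_2>"])
    print(batch.uid)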
846 | 827 | def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": |
847 | 828 | """ |
@@ -1127,6 +1108,99 @@ def batches(self) -> PaginatedCollection: |
1127 | 1108 | cursor_path=['project', 'batches', 'pageInfo', 'endCursor'], |
1128 | 1109 | experimental=True) |
1129 | 1110 |
|
| 1111 | + def task_queues(self) -> List[TaskQueue]: |
| 1112 | + """ Fetch all task queues that belong to this project |
| 1113 | +
|
| 1114 | + Returns: |
| 1115 | + A `List` of `TaskQueue`s |
| 1116 | + """ |
| 1117 | + query_str = """query GetProjectTaskQueuesPyApi($projectId: ID!) { |
| 1118 | + project(where: {id: $projectId}) { |
| 1119 | + taskQueues { |
| 1120 | + %s |
| 1121 | + } |
| 1122 | + } |
| 1123 | + } |
| 1124 | + """ % (query.results_query_part(Entity.TaskQueue)) |
| 1125 | + |
| 1126 | + task_queue_values = self.client.execute( |
| 1127 | + query_str, {"projectId": self.uid}, |
| 1128 | + timeout=180.0, |
| 1129 | + experimental=True)["project"]["taskQueues"] |
| 1130 | + |
| 1131 | + return [ |
| 1132 | + Entity.TaskQueue(self.client, field_values) |
| 1133 | + for field_values in task_queue_values |
| 1134 | + ] |
| 1135 | + |
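A hedged usage sketch for the new `task_queues()` helper; the queue name "Review" and the `name` field on `TaskQueue` are illustrative assumptions:

    # Fetch every task queue on the project and pick one out by name.
    queues = project.task_queues()
    review_queue = next(q for q in queues if q.name == "Review")  # assumed name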
| 1136 | + def move_data_rows_to_task(self, data_row_ids: List[str], |
| 1137 | + task_queue_id: str): |
| 1138 | + """ |
| 1139 | +
|
| 1140 | + Moves data rows to the specified task queue. |
| 1141 | +
|
| 1142 | + Args: |
| 1143 | + data_row_ids: a list of data row ids to be moved |
| 1144 | + task_queue_id: the id of the task queue to move the data rows to, or None to move them to the "Done" queue |
| 1145 | +
|
| 1146 | + Returns: |
| 1147 | + None if successful; a `LabelboxError` is raised on failure |
| 1148 | +
|
| 1149 | + """ |
| 1150 | + method = "createBulkAddRowsToQueueTask" |
| 1151 | + query_str = """mutation AddDataRowsToTaskQueueAsyncPyApi( |
| 1152 | + $projectId: ID! |
| 1153 | + $queueId: ID |
| 1154 | + $dataRowIds: [ID!]! |
| 1155 | + ) { |
| 1156 | + project(where: { id: $projectId }) { |
| 1157 | + %s( |
| 1158 | + data: { queueId: $queueId, dataRowIds: $dataRowIds } |
| 1159 | + ) { |
| 1160 | + taskId |
| 1161 | + } |
| 1162 | + } |
| 1163 | + } |
| 1164 | + """ % method |
| 1165 | + |
| 1166 | + task_id = self.client.execute( |
| 1167 | + query_str, { |
| 1168 | + "projectId": self.uid, |
| 1169 | + "queueId": task_queue_id, |
| 1170 | + "dataRowIds": data_row_ids |
| 1171 | + }, |
| 1172 | + timeout=180.0, |
| 1173 | + experimental=True)["project"][method]["taskId"] |
| 1174 | + |
| 1175 | + status = self._wait_for_task(task_id) |
| 1176 | + if status != "COMPLETE": |
| 1177 | + raise LabelboxError(f"Data rows were not moved successfully") |
| 1178 | + |
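Continuing the sketch above, the rows can then be moved into that queue (data row ids are placeholders; per the docstring, passing `None` instead of a queue id targets the "Done" queue):

    # Move two data rows into the chosen queue; the call blocks until the
    # server-side task finishes or _wait_for_task times out.
    project.move_data_rows_to_task(
        ["<DATA_ROW_ID_1>", "<DATA_ROW_ID_2>"], review_queue.uid)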
| 1179 | + def _wait_for_task(self, task_id: str) -> str: |
| 1180 | + timeout_seconds = 600 |
| 1181 | + sleep_time = 2 |
| 1182 | + get_task_query_str = """query %s($taskId: ID!) { |
| 1183 | + task(where: {id: $taskId}) { |
| 1184 | + status |
| 1185 | + } |
| 1186 | + } |
| 1187 | + """ % "getTaskPyApi" |
| 1188 | + |
| 1189 | + while True: |
| 1190 | + task_status = self.client.execute( |
| 1191 | + get_task_query_str, {'taskId': task_id}, |
| 1192 | + experimental=True)['task']['status'] |
| 1193 | + |
| 1194 | + if task_status == "IN_PROGRESS": |
| 1195 | + timeout_seconds -= sleep_time |
| 1196 | + if timeout_seconds <= 0: |
| 1197 | + raise LabelboxError( |
| 1198 | + f"Timed out while waiting for task to be completed.") |
| 1199 | + time.sleep(sleep_time) |
| 1200 | + continue |
| 1201 | + |
| 1202 | + return task_status |
| 1203 | + |
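The helper budgets roughly 600 seconds, sleeping two seconds between polls, and hands any status other than IN_PROGRESS back to the caller as terminal; a generic sketch of the same pattern (the `wait_for` name and the `poll` callable are hypothetical, not part of the SDK):

    import time

    def wait_for(poll, timeout_seconds=600, sleep_time=2):
        # Mirrors _wait_for_task: `poll` returns a status string; anything
        # other than IN_PROGRESS is returned to the caller as terminal.
        while True:
            status = poll()
            if status != "IN_PROGRESS":
                return status
            timeout_seconds -= sleep_time
            if timeout_seconds <= 0:
                raise TimeoutError("Timed out while waiting for the task.")
            time.sleep(sleep_time)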
1130 | 1204 | def upload_annotations( |
1131 | 1205 | self, |
1132 | 1206 | name: str, |
|