|
| 1 | +from typing import Optional, List |
| 2 | +from labelbox.exceptions import ResourceNotFoundError |
1 | 3 | from labelbox.orm.db_object import DbObject |
2 | | -from labelbox.orm.model import Field |
| 4 | +from labelbox.orm.model import Entity, Field |
3 | 5 | from labelbox.pagination import PaginatedCollection |
| 6 | +from labelbox.schema.export_params import CatalogSliceExportParams |
| 7 | +from labelbox.schema.task import Task |
| 8 | +from labelbox.schema.user import User |
4 | 9 |
|
5 | 10 |
|
6 | 11 | class Slice(DbObject): |
@@ -59,6 +64,85 @@ def get_data_row_ids(self) -> PaginatedCollection: |
59 | 64 | obj_class=lambda _, data_row_id: data_row_id, |
60 | 65 | cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor']) |
61 | 66 |
|
| 67 | + def export_v2(self, |
| 68 | + task_name: Optional[str] = None, |
| 69 | + params: Optional[CatalogSliceExportParams] = None) -> Task: |
| 70 | + """ |
| 71 | + Creates a slice export task with the given params and returns the task. |
| 72 | + >>> slice = client.get_catalog_slice("SLICE_ID") |
| 73 | + >>> task = slice.export_v2( |
| 74 | + >>> params={"performance_details": False, "label_details": True} |
| 75 | + >>> ) |
| 76 | + >>> task.wait_till_done() |
| 77 | + >>> task.result |
| 78 | + """ |
| 79 | + |
| 80 | + _params = params or CatalogSliceExportParams({ |
| 81 | + "attachments": False, |
| 82 | + "metadata_fields": False, |
| 83 | + "data_row_details": False, |
| 84 | + "project_details": False, |
| 85 | + "performance_details": False, |
| 86 | + "label_details": False, |
| 87 | + "media_type_override": None, |
| 88 | + "model_runs_ids": None, |
| 89 | + "projects_ids": None, |
| 90 | + }) |
| 91 | + |
| 92 | + mutation_name = "exportDataRowsInSlice" |
| 93 | + create_task_query_str = """mutation exportDataRowsInSlicePyApi($input: ExportDataRowsInSliceInput!){ |
| 94 | + %s(input: $input) {taskId} } |
| 95 | + """ % (mutation_name) |
| 96 | + |
| 97 | + media_type_override = _params.get('media_type_override', None) |
| 98 | + query_params = { |
| 99 | + "input": { |
| 100 | + "taskName": task_name, |
| 101 | + "filters": { |
| 102 | + "sliceId": self.uid |
| 103 | + }, |
| 104 | + "params": { |
| 105 | + "mediaTypeOverride": |
| 106 | + media_type_override.value |
| 107 | + if media_type_override is not None else None, |
| 108 | + "includeAttachments": |
| 109 | + _params.get('attachments', False), |
| 110 | + "includeMetadata": |
| 111 | + _params.get('metadata_fields', False), |
| 112 | + "includeDataRowDetails": |
| 113 | + _params.get('data_row_details', False), |
| 114 | + "includeProjectDetails": |
| 115 | + _params.get('project_details', False), |
| 116 | + "includePerformanceDetails": |
| 117 | + _params.get('performance_details', False), |
| 118 | + "includeLabelDetails": |
| 119 | + _params.get('label_details', False), |
| 120 | + "projectIds": |
| 121 | + _params.get('projects_ids', None), |
| 122 | + "modelRunIds": |
| 123 | + _params.get('model_runs_ids', None), |
| 124 | + }, |
| 125 | + } |
| 126 | + } |
| 127 | + |
| 128 | + res = self.client.execute( |
| 129 | + create_task_query_str, |
| 130 | + query_params, |
| 131 | + ) |
| 132 | + res = res[mutation_name] |
| 133 | + task_id = res["taskId"] |
| 134 | + user: User = self.client.get_user() |
| 135 | + tasks: List[Task] = list( |
| 136 | + user.created_tasks(where=Entity.Task.uid == task_id)) |
| 137 | + # Cache user in a private variable as the relationship can't be |
| 138 | + # resolved due to server-side limitations (see Task.created_by) |
| 139 | + # for more info. |
| 140 | + if len(tasks) != 1: |
| 141 | + raise ResourceNotFoundError(Entity.Task, task_id) |
| 142 | + task: Task = tasks[0] |
| 143 | + task._user = user |
| 144 | + return task |
| 145 | + |
62 | 146 |
|
63 | 147 | class ModelSlice(Slice): |
64 | 148 | """ |
|
0 commit comments