|
1 | 1 | import logging |
2 | | -from typing import TYPE_CHECKING, Optional |
| 2 | +from typing import TYPE_CHECKING, Collection, Dict, List, Optional |
3 | 3 | import json |
| 4 | +from labelbox.exceptions import ResourceNotFoundError |
4 | 5 |
|
5 | 6 | from labelbox.orm import query |
6 | 7 | from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable |
7 | 8 | from labelbox.orm.model import Entity, Field, Relationship |
8 | 9 | from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore |
| 10 | +from labelbox.schema.export_params import CatalogExportParams |
| 11 | +from labelbox.schema.task import Task |
| 12 | +from labelbox.schema.user import User # type: ignore |
9 | 13 |
|
10 | 14 | if TYPE_CHECKING: |
11 | | - from labelbox import AssetAttachment |
| 15 | + from labelbox import AssetAttachment, Client |
12 | 16 |
|
13 | 17 | logger = logging.getLogger(__name__) |
14 | 18 |
|
@@ -150,3 +154,108 @@ def create_attachment(self, |
150 | 154 | }) |
151 | 155 | return Entity.AssetAttachment(self.client, |
152 | 156 | res["createDataRowAttachment"]) |
| 157 | + |
@staticmethod
def export_v2(client: 'Client',
              data_rows: List['DataRow'],
              task_name: Optional[str] = None,
              params: Optional[CatalogExportParams] = None) -> Task:
    """
    Creates a data rows export task for the given data rows and returns it.

    Args:
        client: Client used to run the export mutation and to look up
            the task it creates.
        data_rows: Data rows to export. Only the ``uid`` of each data
            row is sent to the server.
        task_name: Name given to the export task. When ``None``, the
            name ``"Export v2: data rows (<count>)"`` is used.
        params: Export options. When omitted (or empty), every optional
            section is excluded and no media type override is sent.

    Returns:
        The export task created by the mutation, with the current user
        cached on it.

    Raises:
        ResourceNotFoundError: If the created task is not found exactly
            once among the tasks created by the current user.

    >>>     dataset = client.get_dataset(DATASET_ID)
    >>>     task = DataRow.export_v2(
    >>>         client=client,
    >>>         data_rows=list(dataset.data_rows()),
    >>>         params={
    >>>             "performance_details": False,
    >>>             "label_details": True
    >>>         })
    >>>     task.wait_till_done()
    >>>     task.result
    """
    _params = params or CatalogExportParams({
        "attachments": False,
        "metadata_fields": False,
        "data_row_details": False,
        "project_details": False,
        "performance_details": False,
        "label_details": False,
        "media_type_override": None,
        "model_runs_ids": None,
        "projects_ids": None,
    })

    mutation_name = "exportDataRowsInCatalog"
    create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
        %s(input: $input) {taskId} }
        """ % (mutation_name)

    # The export is scoped to exactly the given data rows via an
    # "id is one of" catalog search filter.
    data_rows_ids = [data_row.uid for data_row in data_rows]
    search_query: List[Dict[str, Collection[str]]] = [{
        "ids": data_rows_ids,
        "operator": "is",
        "type": "data_row_id"
    }]

    media_type_override = _params.get('media_type_override', None)

    if task_name is None:
        task_name = f"Export v2: data rows ({len(data_rows_ids)})"

    query_params = {
        "input": {
            "taskName": task_name,
            "filters": {
                "searchQuery": {
                    "scope": None,
                    "query": search_query
                }
            },
            "params": {
                "mediaTypeOverride":
                    media_type_override.value
                    if media_type_override is not None else None,
                "includeAttachments":
                    _params.get('attachments', False),
                "includeMetadata":
                    _params.get('metadata_fields', False),
                "includeDataRowDetails":
                    _params.get('data_row_details', False),
                "includeProjectDetails":
                    _params.get('project_details', False),
                "includePerformanceDetails":
                    _params.get('performance_details', False),
                "includeLabelDetails":
                    _params.get('label_details', False)
            },
        }
    }

    logger.debug("Creating export task %r for %d data rows", task_name,
                 len(data_rows_ids))
    res = client.execute(
        create_task_query_str,
        query_params,
    )
    logger.debug("Export task mutation response: %s", res)

    task_id = res[mutation_name]["taskId"]
    user: User = client.get_user()
    tasks: List[Task] = list(
        user.created_tasks(where=Entity.Task.uid == task_id))
    if len(tasks) != 1:
        raise ResourceNotFoundError(Entity.Task, task_id)
    task: Task = tasks[0]
    # Cache user in a private variable as the relationship can't be
    # resolved due to server-side limitations (see Task.created_by)
    # for more info.
    task._user = user
    return task
0 commit comments