Skip to content

Commit cdea274

Browse files
committed
Add static method to export data rows + add tests
1 parent 4f33e96 commit cdea274

File tree

4 files changed

+127
-3
lines changed

4 files changed

+127
-3
lines changed

labelbox/schema/data_row.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import logging
2-
from typing import TYPE_CHECKING, Optional
2+
from typing import TYPE_CHECKING, Collection, Dict, List, Optional
33
import json
4+
from labelbox.exceptions import ResourceNotFoundError
45

56
from labelbox.orm import query
67
from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
78
from labelbox.orm.model import Entity, Field, Relationship
89
from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore
10+
from labelbox.schema.export_params import CatalogExportParams
11+
from labelbox.schema.task import Task
12+
from labelbox.schema.user import User # type: ignore
913

1014
if TYPE_CHECKING:
11-
from labelbox import AssetAttachment
15+
from labelbox import AssetAttachment, Client
1216

1317
logger = logging.getLogger(__name__)
1418

@@ -150,3 +154,108 @@ def create_attachment(self,
150154
})
151155
return Entity.AssetAttachment(self.client,
152156
res["createDataRowAttachment"])
157+
158+
@staticmethod
def export_v2(client: 'Client',
              data_rows: List['DataRow'],
              task_name: Optional[str] = None,
              params: Optional[CatalogExportParams] = None) -> Task:
    """
    Creates a data rows export task with the given list, params and returns the task.

    Args:
        client (Client): The client used to issue the export mutation.
        data_rows (list of DataRow): The data rows to export; their uids
            are sent to the server as the export filter.
        task_name (str, optional): Display name for the export task.
            Defaults to "Export v2: data rows (<count>)".
        params (CatalogExportParams, optional): Export options. When
            omitted, every optional section (attachments, metadata,
            details, ...) is disabled.

    Returns:
        Task: The created export task, with its creator cached on it.

    Raises:
        ResourceNotFoundError: If the created task cannot be found among
            the current user's tasks.

    >>> dataset = client.get_dataset(DATASET_ID)
    >>> task = DataRow.export_v2(
    >>>     client=client,
    >>>     data_rows=[data_row for data_row in dataset.data_rows.list()],
    >>>     params={
    >>>         "performance_details": False,
    >>>         "label_details": True
    >>>     })
    >>> task.wait_till_done()
    >>> task.result
    """
    # Default export: only the base data row payload, all optional
    # sections disabled.
    _params = params or CatalogExportParams({
        "attachments": False,
        "metadata_fields": False,
        "data_row_details": False,
        "project_details": False,
        "performance_details": False,
        "label_details": False,
        "media_type_override": None,
        "model_runs_ids": None,
        "projects_ids": None,
    })

    mutation_name = "exportDataRowsInCatalog"
    create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
        %s(input: $input) {taskId} }
        """ % (mutation_name)

    # The catalog export filters by data row id; build a single
    # "is one of these ids" search clause.
    data_rows_ids = [data_row.uid for data_row in data_rows]
    search_query: List[Dict[str, Collection[str]]] = []
    search_query.append({
        "ids": data_rows_ids,
        "operator": "is",
        "type": "data_row_id"
    })

    media_type_override = _params.get('media_type_override', None)

    if task_name is None:
        task_name = f"Export v2: data rows ({len(data_rows_ids)})"
    query_params = {
        "input": {
            "taskName": task_name,
            "filters": {
                "searchQuery": {
                    "scope": None,
                    "query": search_query
                }
            },
            "params": {
                "mediaTypeOverride":
                    media_type_override.value
                    if media_type_override is not None else None,
                "includeAttachments":
                    _params.get('attachments', False),
                "includeMetadata":
                    _params.get('metadata_fields', False),
                "includeDataRowDetails":
                    _params.get('data_row_details', False),
                "includeProjectDetails":
                    _params.get('project_details', False),
                "includePerformanceDetails":
                    _params.get('performance_details', False),
                "includeLabelDetails":
                    _params.get('label_details', False)
            },
        }
    }

    res = client.execute(
        create_task_query_str,
        query_params,
    )
    res = res[mutation_name]
    task_id = res["taskId"]
    user: User = client.get_user()
    tasks: List[Task] = list(
        user.created_tasks(where=Entity.Task.uid == task_id))
    # Cache user in a private variable as the relationship can't be
    # resolved due to server-side limitations (see Task.created_by)
    # for more info.
    if len(tasks) != 1:
        raise ResourceNotFoundError(Entity.Task, task_id)
    task: Task = tasks[0]
    task._user = user
    return task

labelbox/schema/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def export_v2(self,
568568
"label_details": False,
569569
"media_type_override": None,
570570
"model_runs_ids": None,
571-
"project_ids": None,
571+
"projects_ids": None,
572572
})
573573

574574
_filters = filters or DatasetExportFilters({

tests/integration/test_data_rows.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from tempfile import NamedTemporaryFile
2+
import time
23
import uuid
34
from datetime import datetime
45
import json
@@ -962,3 +963,13 @@ def test_create_data_row_with_media_type(dataset, image_url):
962963
assert "Found invalid contents for media type: \'IMAGE\'" in str(exc.value)
963964

964965
dataset.create_data_row(row_data=image_url, media_type="IMAGE")
966+
967+
968+
def test_export_data_rows(client, datarow):
    """Export a single data row via DataRow.export_v2 and verify the result."""
    # Give the backend time to index the newly created data row before
    # asking for an export of it.
    time.sleep(10)

    export_task = DataRow.export_v2(client=client, data_rows=[datarow])
    export_task.wait_till_done()

    assert export_task.status == "COMPLETE"
    assert export_task.errors is None
    assert len(export_task.result) == 1

tests/integration/test_dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import time
23
import pytest
34
import requests
45
from labelbox import Dataset
@@ -153,10 +154,13 @@ def test_dataset_export_v2(dataset, image_url):
153154
ids = set()
154155
for _ in range(n_data_rows):
155156
ids.add(dataset.create_data_row(row_data=image_url))
157+
158+
time.sleep(10)
156159
task = dataset.export_v2(params={
157160
"performance_details": False,
158161
"label_details": True
159162
})
163+
task.wait_till_done()
160164
assert task.status == "COMPLETE"
161165
assert task.errors is None
162166
assert len(task.result) == n_data_rows

0 commit comments

Comments
 (0)