
Commit 784a121

Add export Dataset.export_v2 method, improve docs
1 parent 1ea2ad3 commit 784a121

5 files changed: +207 −12 lines changed


labelbox/schema/dataset.py

Lines changed: 191 additions & 4 deletions

@@ -1,4 +1,4 @@
-from typing import Generator, List, Union, Any, TYPE_CHECKING
+from typing import Collection, Dict, Generator, List, Optional, Union, Any, TYPE_CHECKING
 import os
 import json
 import logging
@@ -17,9 +17,12 @@
 from labelbox.orm.model import Entity, Field, Relationship
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
-
-if TYPE_CHECKING:
-    from labelbox import Task, User, DataRow
+from labelbox.schema.data_row import DataRow
+from labelbox.schema.export_filters import DatasetExportFilters, SharedExportFilters
+from labelbox.schema.export_params import CatalogExportParams
+from labelbox.schema.project import _validate_datetime
+from labelbox.schema.task import Task
+from labelbox.schema.user import User
 
 logger = logging.getLogger(__name__)
 
@@ -534,3 +537,187 @@ def export_data_rows(self,
             logger.debug("Dataset '%s' data row export, waiting for server...",
                          self.uid)
             time.sleep(sleep_time)
+
+    def export_v2(self,
+                  task_name: Optional[str] = None,
+                  filters: Optional[DatasetExportFilters] = None,
+                  params: Optional[CatalogExportParams] = None) -> Task:
+        """
+        Creates a dataset export task with the given params and returns the task.
+
+        >>> dataset = client.get_dataset(DATASET_ID)
+        >>> task = dataset.export_v2(
+        >>>     filters={
+        >>>       "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
+        >>>       "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
+        >>>     },
+        >>>     params={
+        >>>       "performance_details": False,
+        >>>       "label_details": True
+        >>>     })
+        >>> task.wait_till_done()
+        >>> task.result
+        """
+
+        _params = params or CatalogExportParams({
+            "attachments": False,
+            "metadata_fields": False,
+            "data_row_details": False,
+            "project_details": False,
+            "performance_details": False,
+            "label_details": False,
+            "media_type_override": None,
+            "model_runs_ids": None,
+            "project_ids": None,
+        })
+
+        _filters = filters or DatasetExportFilters({
+            "last_activity_at": None,
+            "label_created_at": None
+        })
+
+        def _get_timezone() -> str:
+            timezone_query_str = """query CurrentUserPyApi { user { timezone } }"""
+            tz_res = self.client.execute(timezone_query_str)
+            return tz_res["user"]["timezone"] or "UTC"
+
+        timezone: Optional[str] = None
+
+        mutation_name = "exportDataRowsInCatalog"
+        create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
+            %s(input: $input) {taskId} }
+            """ % (mutation_name)
+
+        search_query: List[Dict[str, Collection[str]]] = []
+        search_query.append({
+            "ids": [self.uid],
+            "operator": "is",
+            "type": "dataset"
+        })
+        media_type_override = _params.get('media_type_override', None)
+
+        if task_name is None:
+            task_name = f"Export v2: dataset - {self.name}"
+        query_params = {
+            "input": {
+                "taskName": task_name,
+                "filters": {
+                    "searchQuery": {
+                        "scope": None,
+                        "query": search_query
+                    }
+                },
+                "params": {
+                    "mediaTypeOverride":
+                        media_type_override.value
+                        if media_type_override is not None else None,
+                    "includeAttachments":
+                        _params.get('attachments', False),
+                    "includeMetadata":
+                        _params.get('metadata_fields', False),
+                    "includeDataRowDetails":
+                        _params.get('data_row_details', False),
+                    "includeProjectDetails":
+                        _params.get('project_details', False),
+                    "includePerformanceDetails":
+                        _params.get('performance_details', False),
+                    "includeLabelDetails":
+                        _params.get('label_details', False)
+                },
+            }
+        }
+
+        if "last_activity_at" in _filters and _filters[
+                'last_activity_at'] is not None:
+            if timezone is None:
+                timezone = _get_timezone()
+            values = _filters['last_activity_at']
+            start, end = values
+            if (start is not None and end is not None):
+                [_validate_datetime(date) for date in values]
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "BETWEEN",
+                        "timezone": timezone,
+                        "value": {
+                            "min": start,
+                            "max": end
+                        }
+                    }
+                })
+            elif (start is not None):
+                _validate_datetime(start)
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "GREATER_THAN_OR_EQUAL",
+                        "timezone": timezone,
+                        "value": start
+                    }
+                })
+            elif (end is not None):
+                _validate_datetime(end)
+                search_query.append({
+                    "type": "data_row_last_activity_at",
+                    "value": {
+                        "operator": "LESS_THAN_OR_EQUAL",
+                        "timezone": timezone,
+                        "value": end
+                    }
+                })
+
+        if "label_created_at" in _filters and _filters[
+                "label_created_at"] is not None:
+            if timezone is None:
+                timezone = _get_timezone()
+            values = _filters['label_created_at']
+            start, end = values
+            if (start is not None and end is not None):
+                [_validate_datetime(date) for date in values]
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "BETWEEN",
+                        "value": {
+                            "min": start,
+                            "max": end
+                        }
+                    }
+                })
+            elif (start is not None):
+                _validate_datetime(start)
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "GREATER_THAN_OR_EQUAL",
+                        "value": start
+                    }
+                })
+            elif (end is not None):
+                _validate_datetime(end)
+                search_query.append({
+                    "type": "labeled_at",
+                    "value": {
+                        "operator": "LESS_THAN_OR_EQUAL",
+                        "value": end
+                    }
+                })
+
+        res = self.client.execute(
+            create_task_query_str,
+            query_params,
+        )
+        res = res[mutation_name]
+        task_id = res["taskId"]
+        user: User = self.client.get_user()
+        tasks: List[Task] = list(
+            user.created_tasks(where=Entity.Task.uid == task_id))
+        # Cache user in a private variable as the relationship can't be
+        # resolved due to server-side limitations (see Task.created_by)
+        # for more info.
+        if len(tasks) != 1:
+            raise ResourceNotFoundError(Entity.Task, task_id)
+        task: Task = tasks[0]
+        task._user = user
+        return task
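
The new method mirrors the existing export_v2 on projects and catalog slices. Below is a minimal usage sketch based on the docstring above; the API key, dataset ID, and the chosen filter/param values are placeholders, and any keys omitted from params fall back to the False/None defaults shown in the diff.

```python
from labelbox import Client

client = Client(api_key="LABELBOX_API_KEY")   # placeholder API key
dataset = client.get_dataset("DATASET_ID")    # placeholder dataset ID

# filters and params are plain dicts matching the DatasetExportFilters and
# CatalogExportParams TypedDicts; missing param keys default to False/None.
task = dataset.export_v2(
    task_name="Export v2: dataset - example",
    filters={
        "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        "label_created_at": ["2000-01-01 00:00:00", None],  # no upper bound
    },
    params={
        "data_row_details": True,
        "performance_details": False,
        "label_details": True,
    },
)

task.wait_till_done()   # poll until the server-side export task completes
print(task.result)
```

Either filter key can be omitted or set to None to skip that constraint, since the method only adds a search-query clause when the key is present and non-None.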

labelbox/schema/export_filters.py

Lines changed: 9 additions & 1 deletion

@@ -9,7 +9,7 @@
 from typing import Tuple
 
 
-class ProjectExportFilters(TypedDict):
+class SharedExportFilters(TypedDict):
     label_created_at: Optional[Tuple[str, str]]
     """ Date range for labels created at
     Formatted "YYYY-MM-DD" or "YYYY-MM-DD hh:mm:ss"
@@ -26,3 +26,11 @@ class ProjectExportFilters(TypedDict):
     >>> [None, "2050-01-01 00:00:00"]
     >>> ["2000-01-01 00:00:00", None]
     """
+
+
+class ProjectExportFilters(SharedExportFilters):
+    pass
+
+
+class DatasetExportFilters(SharedExportFilters):
+    pass
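
Both concrete filter classes are now empty subclasses of SharedExportFilters, so they accept the same keys. A small sketch of how that reads for callers, assuming (as the Dataset.export_v2 defaults suggest) that last_activity_at and label_created_at are the shared fields:

```python
from labelbox.schema.export_filters import (DatasetExportFilters,
                                            ProjectExportFilters)

# TypedDict classes are ordinary dicts at runtime; wrapping the literal just
# attaches the static type, exactly as Dataset.export_v2 does internally.
dataset_filters = DatasetExportFilters({
    "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    "label_created_at": None,
})
project_filters = ProjectExportFilters({
    "last_activity_at": None,
    "label_created_at": [None, "2050-01-01 00:00:00"],  # upper bound only
})
```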

labelbox/schema/export_params.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ class ProjectExportParams(DataRowParams):
     performance_details: Optional[bool]
 
 
-class CatalogSliceExportParams(DataRowParams):
+class CatalogExportParams(DataRowParams):
     project_details: Optional[bool]
     label_details: Optional[bool]
     performance_details: Optional[bool]
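
Only the class name changes here; the field set stays the same. For reference, a fully spelled-out instance using the same keys Dataset.export_v2 supplies as defaults (which of these keys live on DataRowParams versus this subclass is not visible in the hunk, so treat the grouping as illustrative):

```python
from labelbox.schema.export_params import CatalogExportParams

# Same keys as the default block in Dataset.export_v2; flip flags as needed.
params = CatalogExportParams({
    "attachments": False,
    "metadata_fields": True,
    "data_row_details": True,
    "project_details": False,
    "performance_details": False,
    "label_details": True,
    "media_type_override": None,
    "model_runs_ids": None,
    "project_ids": None,
})
```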

labelbox/schema/project.py

Lines changed: 3 additions & 3 deletions

@@ -420,7 +420,7 @@ def export_v2(self,
                   filters: Optional[ProjectExportFilters] = None,
                   params: Optional[ProjectExportParams] = None) -> Task:
         """
-        Creates a project run export task with the given params and returns the task.
+        Creates a project export task with the given params and returns the task.
 
         For more information visit: https://docs.labelbox.com/docs/exports-v2#export-from-a-project-python-sdk
 
@@ -430,8 +430,8 @@ def export_v2(self,
         >>>       "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
         >>>     },
         >>>     params={
-        >>>       "include_performance_details": False,
-        >>>       "include_labels": True
+        >>>       "performance_details": False,
+        >>>       "label_details": True
         >>>     })
         >>> task.wait_till_done()
         >>> task.result
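
The docstring fix above aligns the example keys with ProjectExportParams. In runnable form, with the API key and project ID as placeholders:

```python
from labelbox import Client

client = Client(api_key="LABELBOX_API_KEY")   # placeholder API key
project = client.get_project("PROJECT_ID")    # placeholder project ID

task = project.export_v2(
    filters={
        "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
    },
    params={
        "performance_details": False,  # was "include_performance_details"
        "label_details": True,         # was "include_labels"
    },
)
task.wait_till_done()
print(task.result)
```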

labelbox/schema/slice.py

Lines changed: 3 additions & 3 deletions

@@ -3,7 +3,7 @@
 from labelbox.orm.db_object import DbObject
 from labelbox.orm.model import Entity, Field
 from labelbox.pagination import PaginatedCollection
-from labelbox.schema.export_params import CatalogSliceExportParams
+from labelbox.schema.export_params import CatalogExportParams
 from labelbox.schema.task import Task
 from labelbox.schema.user import User
 
@@ -66,7 +66,7 @@ def get_data_row_ids(self) -> PaginatedCollection:
 
     def export_v2(self,
                   task_name: Optional[str] = None,
-                  params: Optional[CatalogSliceExportParams] = None) -> Task:
+                  params: Optional[CatalogExportParams] = None) -> Task:
         """
         Creates a slice export task with the given params and returns the task.
         >>> slice = client.get_catalog_slice("SLICE_ID")
@@ -77,7 +77,7 @@ def export_v2(self,
         >>> task.result
         """
 
-        _params = params or CatalogSliceExportParams({
+        _params = params or CatalogExportParams({
            "attachments": False,
            "metadata_fields": False,
            "data_row_details": False,
