
Commit 0e049dd

[AL-5383] Export list of datarow IDs specified by user (#1061)
1 parent 0883e01 commit 0e049dd

14 files changed: +401, -349 lines changed


labelbox/schema/dataset.py

Lines changed: 14 additions & 95 deletions
@@ -18,9 +18,8 @@
 from labelbox.orm import query
 from labelbox.exceptions import MalformedQueryException
 from labelbox.schema.data_row import DataRow
-from labelbox.schema.export_filters import DatasetExportFilters, SharedExportFilters
+from labelbox.schema.export_filters import DatasetExportFilters, build_filters
 from labelbox.schema.export_params import CatalogExportParams, validate_catalog_export_params
-from labelbox.schema.project import _validate_datetime
 from labelbox.schema.task import Task
 from labelbox.schema.user import User

@@ -549,7 +548,8 @@ def export_v2(self,
         >>> task = dataset.export_v2(
         >>>     filters={
         >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
-        >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"]
+        >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
+        >>>         "data_row_ids": [DATA_ROW_ID_1, DATA_ROW_ID_2, ...]
         >>>     },
         >>>     params={
         >>>         "performance_details": False,
@@ -574,38 +574,26 @@ def export_v2(self,
 
         _filters = filters or DatasetExportFilters({
             "last_activity_at": None,
-            "label_created_at": None
+            "label_created_at": None,
+            "data_row_ids": None,
         })
 
-        def _get_timezone() -> str:
-            timezone_query_str = """query CurrentUserPyApi { user { timezone } }"""
-            tz_res = self.client.execute(timezone_query_str)
-            return tz_res["user"]["timezone"] or "UTC"
-
-        timezone: Optional[str] = None
-
         mutation_name = "exportDataRowsInCatalog"
         create_task_query_str = """mutation exportDataRowsInCatalogPyApi($input: ExportDataRowsInCatalogInput!){
             %s(input: $input) {taskId} }
             """ % (mutation_name)
 
-        search_query: List[Dict[str, Collection[str]]] = []
-        search_query.append({
-            "ids": [self.uid],
-            "operator": "is",
-            "type": "dataset"
-        })
         media_type_override = _params.get('media_type_override', None)
 
         if task_name is None:
             task_name = f"Export v2: dataset - {self.name}"
-        query_params = {
+        query_params: Dict[str, Any] = {
             "input": {
                 "taskName": task_name,
                 "filters": {
                     "searchQuery": {
                         "scope": None,
-                        "query": search_query
+                        "query": None,
                     }
                 },
                 "params": {
@@ -632,82 +620,13 @@ def _get_timezone() -> str:
             }
         }
 
-        if "last_activity_at" in _filters and _filters[
-                'last_activity_at'] is not None:
-            if timezone is None:
-                timezone = _get_timezone()
-            values = _filters['last_activity_at']
-            start, end = values
-            if (start is not None and end is not None):
-                [_validate_datetime(date) for date in values]
-                search_query.append({
-                    "type": "data_row_last_activity_at",
-                    "value": {
-                        "operator": "BETWEEN",
-                        "timezone": timezone,
-                        "value": {
-                            "min": start,
-                            "max": end
-                        }
-                    }
-                })
-            elif (start is not None):
-                _validate_datetime(start)
-                search_query.append({
-                    "type": "data_row_last_activity_at",
-                    "value": {
-                        "operator": "GREATER_THAN_OR_EQUAL",
-                        "timezone": timezone,
-                        "value": start
-                    }
-                })
-            elif (end is not None):
-                _validate_datetime(end)
-                search_query.append({
-                    "type": "data_row_last_activity_at",
-                    "value": {
-                        "operator": "LESS_THAN_OR_EQUAL",
-                        "timezone": timezone,
-                        "value": end
-                    }
-                })
-
-        if "label_created_at" in _filters and _filters[
-                "label_created_at"] is not None:
-            if timezone is None:
-                timezone = _get_timezone()
-            values = _filters['label_created_at']
-            start, end = values
-            if (start is not None and end is not None):
-                [_validate_datetime(date) for date in values]
-                search_query.append({
-                    "type": "labeled_at",
-                    "value": {
-                        "operator": "BETWEEN",
-                        "value": {
-                            "min": start,
-                            "max": end
-                        }
-                    }
-                })
-            elif (start is not None):
-                _validate_datetime(start)
-                search_query.append({
-                    "type": "labeled_at",
-                    "value": {
-                        "operator": "GREATER_THAN_OR_EQUAL",
-                        "value": start
-                    }
-                })
-            elif (end is not None):
-                _validate_datetime(end)
-                search_query.append({
-                    "type": "labeled_at",
-                    "value": {
-                        "operator": "LESS_THAN_OR_EQUAL",
-                        "value": end
-                    }
-                })
+        search_query = build_filters(self.client, _filters)
+        search_query.append({
+            "ids": [self.uid],
+            "operator": "is",
+            "type": "dataset"
+        })
+        query_params["input"]["filters"]["searchQuery"]["query"] = search_query
 
         res = self.client.execute(
             create_task_query_str,
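
For context, a minimal usage sketch of the new `data_row_ids` filter, mirroring the updated `export_v2` docstring above. The API key, dataset ID, and data row IDs are placeholders; `Client.get_dataset` and `Task.wait_till_done` are the standard SDK helpers assumed here.

import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")        # placeholder credentials
dataset = client.get_dataset("YOUR_DATASET_ID")   # placeholder dataset ID

task = dataset.export_v2(
    filters={
        "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        # New in this commit: restrict the export to an explicit list of data rows.
        "data_row_ids": ["DATA_ROW_ID_1", "DATA_ROW_ID_2"],
    },
    params={
        "performance_details": False,
    })
task.wait_till_done()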

labelbox/schema/export_filters.py

Lines changed: 125 additions & 2 deletions
@@ -1,12 +1,14 @@
 import sys
 
-from typing import Optional
+from datetime import datetime
+from typing import Collection, Dict, Tuple, List, Optional
+
 if sys.version_info >= (3, 8):
     from typing import TypedDict
 else:
     from typing_extensions import TypedDict
 
-from typing import Tuple
+MAX_DATA_ROW_IDS_PER_EXPORT_V2 = 2_000
 
 
 class SharedExportFilters(TypedDict):
@@ -26,6 +28,12 @@ class SharedExportFilters(TypedDict):
     >>> [None, "2050-01-01 00:00:00"]
     >>> ["2000-01-01 00:00:00", None]
     """
+    data_row_ids: Optional[List[str]]
+    """ Data row IDs to export
+    Only allows up to MAX_DATA_ROW_IDS_PER_EXPORT_V2 data rows
+    Example:
+    >>> ["clgo3lyax0000veeezdbu3ws4", "clgo3lzjl0001veeer6y6z8zp", ...]
+    """
 
 
 class ProjectExportFilters(SharedExportFilters):
@@ -34,3 +42,118 @@ class ProjectExportFilters(SharedExportFilters):
 
 class DatasetExportFilters(SharedExportFilters):
     pass
+
+
+def validate_datetime(string_date: str) -> bool:
+    """Helper that validates a datetime string for the export, formatted as YYYY-MM-DD or YYYY-MM-DD hh:mm:ss."""
+    if string_date:
+        for fmt in ("%Y-%m-%d", "%Y-%m-%d %H:%M:%S"):
+            try:
+                datetime.strptime(string_date, fmt)
+                return True
+            except ValueError:
+                pass
+        raise ValueError(f"""Incorrect format for: {string_date}.
+        Format must be \"YYYY-MM-DD\" or \"YYYY-MM-DD hh:mm:ss\"""")
+    return True
+
+
+def build_filters(client, filters):
+    search_query: List[Dict[str, Collection[str]]] = []
+    timezone: Optional[str] = None
+
+    def _get_timezone() -> str:
+        timezone_query_str = """query CurrentUserPyApi { user { timezone } }"""
+        tz_res = client.execute(timezone_query_str)
+        return tz_res["user"]["timezone"] or "UTC"
+
+    last_activity_at = filters.get("last_activity_at")
+    if last_activity_at:
+        if timezone is None:
+            timezone = _get_timezone()
+        start, end = last_activity_at
+        if (start is not None and end is not None):
+            [validate_datetime(date) for date in last_activity_at]
+            search_query.append({
+                "type": "data_row_last_activity_at",
+                "value": {
+                    "operator": "BETWEEN",
+                    "timezone": timezone,
+                    "value": {
+                        "min": start,
+                        "max": end
+                    }
+                }
+            })
+        elif (start is not None):
+            validate_datetime(start)
+            search_query.append({
+                "type": "data_row_last_activity_at",
+                "value": {
+                    "operator": "GREATER_THAN_OR_EQUAL",
+                    "timezone": timezone,
+                    "value": start
+                }
+            })
+        elif (end is not None):
+            validate_datetime(end)
+            search_query.append({
+                "type": "data_row_last_activity_at",
+                "value": {
+                    "operator": "LESS_THAN_OR_EQUAL",
+                    "timezone": timezone,
+                    "value": end
+                }
+            })
+
+    label_created_at = filters.get("label_created_at")
+    if label_created_at:
+        if timezone is None:
+            timezone = _get_timezone()
+        start, end = label_created_at
+        if (start is not None and end is not None):
+            [validate_datetime(date) for date in label_created_at]
+            search_query.append({
+                "type": "labeled_at",
+                "value": {
+                    "operator": "BETWEEN",
+                    "value": {
+                        "min": start,
+                        "max": end
+                    }
+                }
+            })
+        elif (start is not None):
+            validate_datetime(start)
+            search_query.append({
+                "type": "labeled_at",
+                "value": {
+                    "operator": "GREATER_THAN_OR_EQUAL",
+                    "value": start
+                }
+            })
+        elif (end is not None):
+            validate_datetime(end)
+            search_query.append({
+                "type": "labeled_at",
+                "value": {
+                    "operator": "LESS_THAN_OR_EQUAL",
+                    "value": end
+                }
+            })
+
+    data_row_ids = filters.get("data_row_ids")
+    if data_row_ids:
+        if not isinstance(data_row_ids, list):
+            raise ValueError("`data_row_ids` filter expects a list.")
+        if len(data_row_ids) > MAX_DATA_ROW_IDS_PER_EXPORT_V2:
+            raise ValueError(
+                f"`data_row_ids` filter only supports a max of {MAX_DATA_ROW_IDS_PER_EXPORT_V2} items."
+            )
+        search_query.append({
+            "ids": data_row_ids,
+            "operator": "is",
+            "type": "data_row_id"
+        })
+
+    return search_query
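
As a quick illustration of the new helper, here is a sketch of what `build_filters` produces when only `data_row_ids` is set. The `client` argument is consulted only by the timezone lookup that the date filters trigger, so a placeholder is enough for this particular case; the IDs are the example values from the docstring above.

from labelbox.schema.export_filters import build_filters

# No date filters here, so _get_timezone (and therefore the client) is never
# called; passing None is only safe in that situation.
search_query = build_filters(
    None,
    {"data_row_ids": ["clgo3lyax0000veeezdbu3ws4", "clgo3lzjl0001veeer6y6z8zp"]})

# Expected shape of the single appended entry:
# [{"ids": [...], "operator": "is", "type": "data_row_id"}]
print(search_query)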

labelbox/schema/export_params.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ class CatalogExportParams(DataRowParams):
     performance_details: Optional[bool]
     model_run_ids: Optional[List[str]]
     project_ids: Optional[List[str]]
-    pass
 
 
 class ModelRunExportParams(DataRowParams):
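
Finally, a hedged sketch of the params side for completeness, limited to the three `CatalogExportParams` fields visible in this hunk; the class defines further optional fields that are not shown here.

# Only the fields visible in this diff; all of them are Optional.
catalog_export_params = {
    "performance_details": False,   # Optional[bool]
    "model_run_ids": None,          # Optional[List[str]]
    "project_ids": None,            # Optional[List[str]]
}

# Passed alongside the filters from the dataset.py example above, e.g.:
# dataset.export_v2(filters={"data_row_ids": [...]}, params=catalog_export_params)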
