Skip to content

Commit 061dab2

Browse files
authored
[PLT-319] Add Catalog class, export_v2 method (#1432)
1 parent a917f89 commit 061dab2

File tree

7 files changed

+233
-0
lines changed

7 files changed

+233
-0
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport, MEAToMALPredictionImport
1010
from labelbox.schema.dataset import Dataset
1111
from labelbox.schema.data_row import DataRow
12+
from labelbox.schema.catalog import Catalog
1213
from labelbox.schema.enums import AnnotationImportState
1314
from labelbox.schema.label import Label
1415
from labelbox.schema.batch import Batch

labelbox/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from labelbox.schema import role
2626
from labelbox.schema.conflict_resolution_strategy import ConflictResolutionStrategy
2727
from labelbox.schema.data_row import DataRow
28+
from labelbox.schema.catalog import Catalog
2829
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
2930
from labelbox.schema.dataset import Dataset
3031
from labelbox.schema.enums import CollectionJobStatus
@@ -1614,6 +1615,9 @@ def _format_failed_rows(rows: List[str],
16141615
"Timed out waiting for clear_global_keys job to complete.")
16151616
time.sleep(sleep_time)
16161617

1618+
def get_catalog(self) -> Catalog:
1619+
return Catalog(client=self)
1620+
16171621
def get_catalog_slice(self, slice_id) -> CatalogSlice:
16181622
"""
16191623
Fetches a Catalog Slice by ID.

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
import labelbox.schema.media_type
2424
import labelbox.schema.identifiables
2525
import labelbox.schema.identifiable
26+
import labelbox.schema.catalog

labelbox/schema/catalog.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from typing import Any, Dict, List, Optional, Union
2+
from labelbox.orm.db_object import experimental
3+
from labelbox.schema.export_filters import CatalogExportFilters, build_filters
4+
5+
from labelbox.schema.export_params import (CatalogExportParams,
6+
validate_catalog_export_params)
7+
from labelbox.schema.export_task import ExportTask
8+
from labelbox.schema.task import Task
9+
10+
from typing import TYPE_CHECKING
11+
if TYPE_CHECKING:
12+
from labelbox import Client
13+
14+
15+
class Catalog:
    """Entry point for catalog-wide (cross-project) data row exports.

    Obtained via ``client.get_catalog()``; holds a reference to the client
    used to issue the export mutations.
    """

    # Client used to execute the GraphQL export mutations.
    client: "Client"

    def __init__(self, client: 'Client'):
        self.client = client

    def export_v2(
        self,
        task_name: Optional[str] = None,
        filters: Union[CatalogExportFilters, Dict[str, List[str]], None] = None,
        params: Optional[CatalogExportParams] = None,
    ) -> Task:
        """
        Creates a catalog export task with the given params, filters and returns the task.

        >>> import labelbox as lb
        >>> client = lb.Client(<API_KEY>)
        >>> catalog = client.get_catalog()
        >>> task = catalog.export_v2(
        >>>     filters={
        >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        >>>     },
        >>>     params={
        >>>         "performance_details": False,
        >>>         "label_details": True
        >>>     })
        >>> task.wait_till_done()
        >>> task.result
        """
        return self._export(task_name, filters, params, False)

    @experimental
    def export(
        self,
        task_name: Optional[str] = None,
        filters: Union[CatalogExportFilters, Dict[str, List[str]], None] = None,
        params: Optional[CatalogExportParams] = None,
    ) -> ExportTask:
        """
        Creates a catalog export task with the given params, filters and returns the task.

        >>> import labelbox as lb
        >>> client = lb.Client(<API_KEY>)
        >>> catalog = client.get_catalog()
        >>> export_task = catalog.export(
        >>>     filters={
        >>>         "last_activity_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        >>>         "label_created_at": ["2000-01-01 00:00:00", "2050-01-01 00:00:00"],
        >>>     },
        >>>     params={
        >>>         "performance_details": False,
        >>>         "label_details": True
        >>>     })
        >>> export_task.wait_till_done()
        >>>
        >>> # Return a JSON output string from the export task results/errors one by one:
        >>> def json_stream_handler(output: lb.JsonConverterOutput):
        >>>     print(output.json_str)
        >>>
        >>> if export_task.has_errors():
        >>>     export_task.get_stream(
        >>>         converter=lb.JsonConverter(),
        >>>         stream_type=lb.StreamType.ERRORS
        >>>     ).start(stream_handler=lambda error: print(error))
        >>>
        >>> if export_task.has_result():
        >>>     export_json = export_task.get_stream(
        >>>         converter=lb.JsonConverter(),
        >>>         stream_type=lb.StreamType.RESULT
        >>>     ).start(stream_handler=json_stream_handler)
        """
        task = self._export(task_name, filters, params, True)
        return ExportTask(task)

    def _export(self,
                task_name: Optional[str] = None,
                filters: Union[CatalogExportFilters, Dict[str, List[str]],
                               None] = None,
                params: Optional[CatalogExportParams] = None,
                streamable: bool = False) -> Task:
        """Shared implementation for ``export_v2`` and ``export``.

        Builds the GraphQL mutation input from the given filters/params,
        launches the server-side export task and returns the Task handle.

        Args:
            task_name: Optional display name for the export task.
            filters: Filter spec narrowing which data rows are exported;
                defaults to no filtering.
            params: Export options; every option defaults to off/None.
            streamable: True when called from the experimental streaming
                ``export`` path; forwarded to the mutation input.
        """
        # Fill in a fully-disabled parameter set when none was provided, then
        # validate before building the mutation payload.
        _params = params or CatalogExportParams({
            "attachments": False,
            "metadata_fields": False,
            "data_row_details": False,
            "project_details": False,
            "performance_details": False,
            "label_details": False,
            "media_type_override": None,
            "model_run_ids": None,
            "project_ids": None,
            "interpolated_frames": False,
            "all_projects": False,
            "all_model_runs": False,
        })
        validate_catalog_export_params(_params)

        _filters = filters or CatalogExportFilters({
            "last_activity_at": None,
            "label_created_at": None,
            "data_row_ids": None,
            "global_keys": None,
        })

        mutation_name = "exportDataRowsInCatalog"
        create_task_query_str = (
            f"mutation {mutation_name}PyApi"
            f"($input: ExportDataRowsInCatalogInput!)"
            f"{{{mutation_name}(input: $input){{taskId}}}}")

        # media_type_override is an enum on the wire; send its value or null.
        media_type_override = _params.get('media_type_override', None)
        query_params: Dict[str, Any] = {
            "input": {
                "taskName": task_name,
                "filters": {
                    "searchQuery": {
                        "scope": None,
                        "query": None,
                    }
                },
                "params": {
                    "mediaTypeOverride":
                        media_type_override.value
                        if media_type_override is not None else None,
                    "includeAttachments":
                        _params.get('attachments', False),
                    "includeMetadata":
                        _params.get('metadata_fields', False),
                    "includeDataRowDetails":
                        _params.get('data_row_details', False),
                    "includeProjectDetails":
                        _params.get('project_details', False),
                    "includePerformanceDetails":
                        _params.get('performance_details', False),
                    "includeLabelDetails":
                        _params.get('label_details', False),
                    "includeInterpolatedFrames":
                        _params.get('interpolated_frames', False),
                    "projectIds":
                        _params.get('project_ids', None),
                    "modelRunIds":
                        _params.get('model_run_ids', None),
                    "allProjects":
                        _params.get('all_projects', False),
                    "allModelRuns":
                        _params.get('all_model_runs', False),
                },
                "streamable": streamable,
            }
        }

        # Translate the user-facing filter spec into the search query the
        # backend expects, then patch it into the mutation input.
        search_query = build_filters(self.client, _filters)
        query_params["input"]["filters"]["searchQuery"]["query"] = search_query

        res = self.client.execute(create_task_query_str,
                                  query_params,
                                  error_log_key="errors")
        res = res[mutation_name]
        task_id = res["taskId"]
        return Task.get_task(self.client, task_id)

labelbox/schema/export_filters.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ class DatasetExportFilters(SharedExportFilters):
5959
pass
6060

6161

62+
class CatalogExportFilters(SharedExportFilters):
    """Filter spec for catalog exports; inherits all shared export filters."""
64+
65+
6266
class DatarowExportFilters(BaseExportFilters):
6367
pass
6468

tests/integration/conftest.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from labelbox.orm import query
1818
from labelbox.pagination import PaginatedCollection
1919
from labelbox.schema.annotation_import import LabelImport
20+
from labelbox.schema.catalog import Catalog
2021
from labelbox.schema.enums import AnnotationImportState
2122
from labelbox.schema.invite import Invite
2223
from labelbox.schema.quality_mode import QualityMode
@@ -744,6 +745,35 @@ def run_dataset_export_v2_task(cls,
744745

745746
return task.result
746747

748+
@classmethod
749+
def run_catalog_export_v2_task(cls,
750+
client,
751+
num_retries=5,
752+
task_name=None,
753+
filters={},
754+
params={}):
755+
task = None
756+
params = params if params else {
757+
"performance_details": False,
758+
"label_details": True
759+
}
760+
catalog = client.get_catalog()
761+
while (num_retries > 0):
762+
763+
task = catalog.export_v2(task_name=task_name,
764+
filters=filters,
765+
params=params)
766+
task.wait_till_done()
767+
assert task.status == "COMPLETE"
768+
assert task.errors is None
769+
if len(task.result) == 0:
770+
num_retries -= 1
771+
time.sleep(5)
772+
else:
773+
break
774+
775+
return task.result
776+
747777

748778
@pytest.fixture
749779
def export_v2_test_helpers() -> Type[ExportV2Helpers]:
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import pytest
2+
3+
4+
@pytest.mark.parametrize('data_rows', [3], indirect=True)
def test_catalog_export_v2(client, export_v2_test_helpers, data_rows):
    # Of the 3 fixture data rows, restrict the export to the first 2 by id.
    filter_size = 2
    all_ids = [dr.uid for dr in data_rows]
    expected_ids = all_ids[:filter_size]

    task_results = export_v2_test_helpers.run_catalog_export_v2_task(
        client,
        filters={"data_row_ids": expected_ids},
        params={"performance_details": False, "label_details": False})

    # Exactly the filtered data rows — no more, no fewer — are exported.
    assert len(task_results) == filter_size
    exported_ids = {dr['data_row']['id'] for dr in task_results}
    assert exported_ids == set(expected_ids)

0 commit comments

Comments
 (0)