Skip to content

Commit 1ea2ad3

Browse files
committed
Add catalog slice SDK support
1 parent 18fe1c1 commit 1ea2ad3

File tree

2 files changed

+95
-2
lines changed

2 files changed

+95
-2
lines changed

labelbox/schema/export_params.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22

3-
from typing import Optional
3+
from typing import Optional, List
44

55
from labelbox.schema.media_type import MediaType
66
if sys.version_info >= (3, 8):
@@ -22,6 +22,15 @@ class ProjectExportParams(DataRowParams):
2222
performance_details: Optional[bool]
2323

2424

25+
class CatalogSliceExportParams(DataRowParams):
26+
project_details: Optional[bool]
27+
label_details: Optional[bool]
28+
performance_details: Optional[bool]
29+
model_runs_ids: Optional[List[str]]
30+
projects_ids: Optional[List[str]]
31+
pass
32+
33+
2534
class ModelRunExportParams(DataRowParams):
2635
# TODO: Add model run fields
2736
pass

labelbox/schema/slice.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
from typing import Optional, List
2+
from labelbox.exceptions import ResourceNotFoundError
13
from labelbox.orm.db_object import DbObject
2-
from labelbox.orm.model import Field
4+
from labelbox.orm.model import Entity, Field
35
from labelbox.pagination import PaginatedCollection
6+
from labelbox.schema.export_params import CatalogSliceExportParams
7+
from labelbox.schema.task import Task
8+
from labelbox.schema.user import User
49

510

611
class Slice(DbObject):
@@ -59,6 +64,85 @@ def get_data_row_ids(self) -> PaginatedCollection:
5964
obj_class=lambda _, data_row_id: data_row_id,
6065
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
6166

67+
def export_v2(self,
68+
task_name: Optional[str] = None,
69+
params: Optional[CatalogSliceExportParams] = None) -> Task:
70+
"""
71+
Creates a slice export task with the given params and returns the task.
72+
>>> slice = client.get_catalog_slice("SLICE_ID")
73+
>>> task = slice.export_v2(
74+
>>> params={"performance_details": False, "label_details": True}
75+
>>> )
76+
>>> task.wait_till_done()
77+
>>> task.result
78+
"""
79+
80+
_params = params or CatalogSliceExportParams({
81+
"attachments": False,
82+
"metadata_fields": False,
83+
"data_row_details": False,
84+
"project_details": False,
85+
"performance_details": False,
86+
"label_details": False,
87+
"media_type_override": None,
88+
"model_runs_ids": None,
89+
"projects_ids": None,
90+
})
91+
92+
mutation_name = "exportDataRowsInSlice"
93+
create_task_query_str = """mutation exportDataRowsInSlicePyApi($input: ExportDataRowsInSliceInput!){
94+
%s(input: $input) {taskId} }
95+
""" % (mutation_name)
96+
97+
media_type_override = _params.get('media_type_override', None)
98+
query_params = {
99+
"input": {
100+
"taskName": task_name,
101+
"filters": {
102+
"sliceId": self.uid
103+
},
104+
"params": {
105+
"mediaTypeOverride":
106+
media_type_override.value
107+
if media_type_override is not None else None,
108+
"includeAttachments":
109+
_params.get('attachments', False),
110+
"includeMetadata":
111+
_params.get('metadata_fields', False),
112+
"includeDataRowDetails":
113+
_params.get('data_row_details', False),
114+
"includeProjectDetails":
115+
_params.get('project_details', False),
116+
"includePerformanceDetails":
117+
_params.get('performance_details', False),
118+
"includeLabelDetails":
119+
_params.get('label_details', False),
120+
"projectIds":
121+
_params.get('projects_ids', None),
122+
"modelRunIds":
123+
_params.get('model_runs_ids', None),
124+
},
125+
}
126+
}
127+
128+
res = self.client.execute(
129+
create_task_query_str,
130+
query_params,
131+
)
132+
res = res[mutation_name]
133+
task_id = res["taskId"]
134+
user: User = self.client.get_user()
135+
tasks: List[Task] = list(
136+
user.created_tasks(where=Entity.Task.uid == task_id))
137+
# Cache user in a private variable as the relationship can't be
138+
# resolved due to server-side limitations (see Task.created_by)
139+
# for more info.
140+
if len(tasks) != 1:
141+
raise ResourceNotFoundError(Entity.Task, task_id)
142+
task: Task = tasks[0]
143+
task._user = user
144+
return task
145+
62146

63147
class ModelSlice(Slice):
64148
"""

0 commit comments

Comments
 (0)