Skip to content

Commit ba106df

Browse files
authored
Merge pull request #931 from Labelbox/kozikkam/AL-4886
[AL-4886] Introduce ModelSlice
2 parents 6816ffa + 5c18470 commit ba106df

File tree

4 files changed

+65
-2
lines changed

4 files changed

+65
-2
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,6 @@
2727
from labelbox.schema.resource_tag import ResourceTag
2828
from labelbox.schema.project_resource_tag import ProjectResourceTag
2929
from labelbox.schema.media_type import MediaType
30-
from labelbox.schema.slice import Slice, CatalogSlice
30+
from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice
3131
from labelbox.schema.queue_mode import QueueMode
3232
from labelbox.schema.task_queue import TaskQueue

labelbox/client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from labelbox.schema.user import User
3434
from labelbox.schema.project import Project
3535
from labelbox.schema.role import Role
36-
from labelbox.schema.slice import CatalogSlice
36+
from labelbox.schema.slice import CatalogSlice, ModelSlice
3737
from labelbox.schema.queue_mode import QueueMode
3838

3939
from labelbox.schema.media_type import MediaType, get_media_type_validation_error
@@ -1384,3 +1384,27 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice:
13841384
"""
13851385
res = self.execute(query_str, {'id': slice_id})
13861386
return Entity.CatalogSlice(self, res['getSavedQuery'])
1387+
1388+
def get_model_slice(self, slice_id) -> ModelSlice:
1389+
"""
1390+
Fetches a Model Slice by ID.
1391+
1392+
Args:
1393+
slice_id (str): The ID of the Slice
1394+
Returns:
1395+
ModelSlice
1396+
"""
1397+
query_str = """
1398+
query getSavedQueryPyApi($id: ID!) {
1399+
getSavedQuery(id: $id) {
1400+
id
1401+
name
1402+
description
1403+
filter
1404+
createdAt
1405+
updatedAt
1406+
}
1407+
}
1408+
"""
1409+
res = self.execute(query_str, {"id": slice_id})
1410+
return Entity.ModelSlice(self, res["getSavedQuery"])

labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ class Entity(metaclass=EntityMeta):
378378
Project: Type[labelbox.Project]
379379
Batch: Type[labelbox.Batch]
380380
CatalogSlice: Type[labelbox.CatalogSlice]
381+
ModelSlice: Type[labelbox.ModelSlice]
381382
TaskQueue: Type[labelbox.TaskQueue]
382383

383384
@classmethod

labelbox/schema/slice.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Slice(DbObject):
1515
updated_at (datetime)
1616
filter (json)
1717
"""
18+
1819
name = Field.String("name")
1920
description = Field.String("description")
2021
created_at = Field.DateTime("created_at")
@@ -57,3 +58,40 @@ def get_data_row_ids(self) -> PaginatedCollection:
5758
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
5859
obj_class=lambda _, data_row_id: data_row_id,
5960
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
61+
62+
63+
class ModelSlice(Slice):
64+
"""
65+
Represents a Slice used for filtering data rows in Model.
66+
"""
67+
68+
def get_data_row_ids(self) -> PaginatedCollection:
69+
"""
70+
Fetches all data row ids that match this Slice
71+
72+
Returns:
73+
A PaginatedCollection of data row ids
74+
"""
75+
query_str = """
76+
query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
77+
getDataRowIdsBySavedQuery(input: {
78+
savedQueryId: $id,
79+
after: $from
80+
first: $first
81+
}) {
82+
totalCount
83+
nodes
84+
pageInfo {
85+
endCursor
86+
hasNextPage
87+
}
88+
}
89+
}
90+
"""
91+
return PaginatedCollection(
92+
client=self.client,
93+
query=query_str,
94+
params={'id': self.uid},
95+
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
96+
obj_class=lambda _, data_row_id: data_row_id,
97+
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])

0 commit comments

Comments
 (0)