Skip to content

Commit 676de2f

Browse files
author
Val Brodsky
committed
Update ModelSlice: fix get_data_row_ids and add get_data_row_identifiers
1 parent e181218 commit 676de2f

File tree

2 files changed

+76
-31
lines changed

2 files changed

+76
-31
lines changed

labelbox/client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,6 +1696,10 @@ def get_model_slice(self, slice_id) -> ModelSlice:
16961696
}
16971697
"""
16981698
res = self.execute(query_str, {"id": slice_id})
1699+
if res is None or res["getSavedQuery"] is None:
1700+
raise labelbox.exceptions.ResourceNotFoundError(
1701+
ModelSlice, slice_id)
1702+
16991703
return Entity.ModelSlice(self, res["getSavedQuery"])
17001704

17011705
def delete_feature_schema_from_ontology(

labelbox/schema/slice.py

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@ class Slice(DbObject):
2929
updated_at = Field.DateTime("updated_at")
3030
filter = Field.Json("filter")
3131

32-
33-
class CatalogSlice(Slice):
34-
"""
35-
Represents a Slice used for filtering data rows in Catalog.
36-
"""
37-
3832
@dataclass
3933
class DataRowIdAndGlobalKey:
4034
id: UniqueId
@@ -44,6 +38,18 @@ def __init__(self, id: str, global_key: Optional[str]):
4438
self.id = UniqueId(id)
4539
self.global_key = GlobalKey(global_key) if global_key else None
4640

41+
def to_hash(self):
42+
return {
43+
"id": self.id.key,
44+
"global_key": self.global_key.key if self.global_key else None
45+
}
46+
47+
48+
class CatalogSlice(Slice):
49+
"""
50+
Represents a Slice used for filtering data rows in Catalog.
51+
"""
52+
4753
def get_data_row_ids(self) -> PaginatedCollection:
4854
"""
4955
Fetches all data row ids that match this Slice
@@ -75,7 +81,7 @@ def get_data_row_ids(self) -> PaginatedCollection:
7581
return PaginatedCollection(
7682
client=self.client,
7783
query=query_str,
78-
params={'id': self.uid},
84+
params={'id': str(self.uid)},
7985
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
8086
obj_class=lambda _, data_row_id: data_row_id,
8187
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
@@ -85,7 +91,7 @@ def get_data_row_identifiers(self) -> PaginatedCollection:
8591
Fetches all data row ids and global keys (where defined) that match this Slice
8692
8793
Returns:
88-
A PaginatedCollection of data row ids
94+
A PaginatedCollection of Slice.DataRowIdAndGlobalKey
8995
"""
9096
query_str = """
9197
query getDataRowIdenfifiersBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
@@ -112,9 +118,9 @@ def get_data_row_identifiers(self) -> PaginatedCollection:
112118
query=query_str,
113119
params={'id': str(self.uid)},
114120
dereferencing=['getDataRowIdentifiersBySavedQuery', 'nodes'],
115-
obj_class=lambda _, data_row_id_and_gk: CatalogSlice.
116-
DataRowIdAndGlobalKey(data_row_id_and_gk.get('id'),
117-
data_row_id_and_gk.get('globalKey', None)),
121+
obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey(
122+
data_row_id_and_gk.get('id'),
123+
data_row_id_and_gk.get('globalKey', None)),
118124
cursor_path=[
119125
'getDataRowIdentifiersBySavedQuery', 'pageInfo', 'endCursor'
120126
])
@@ -224,33 +230,68 @@ class ModelSlice(Slice):
224230
Represents a Slice used for filtering data rows in Model.
225231
"""
226232

233+
@classmethod
234+
def query_str(cls):
235+
query_str = """
236+
query getDataRowIdenfifiersBySavedModelQueryPyApi($id: ID!, $from: DataRowIdentifierCursorInput, $first: Int!) {
237+
getDataRowIdentifiersBySavedModelQuery(input: {
238+
savedQueryId: $id,
239+
after: $from
240+
first: $first
241+
}) {
242+
totalCount
243+
nodes
244+
{
245+
id
246+
globalKey
247+
}
248+
pageInfo {
249+
endCursor {
250+
dataRowId
251+
globalKey
252+
}
253+
hasNextPage
254+
}
255+
}
256+
}
257+
"""
258+
return query_str
259+
227260
def get_data_row_ids(self) -> PaginatedCollection:
228261
"""
229262
Fetches all data row ids that match this Slice
230263
231264
Returns:
232265
A PaginatedCollection of data row ids
233266
"""
234-
query_str = """
235-
query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
236-
getDataRowIdsBySavedQuery(input: {
237-
savedQueryId: $id,
238-
after: $from
239-
first: $first
240-
}) {
241-
totalCount
242-
nodes
243-
pageInfo {
244-
endCursor
245-
hasNextPage
246-
}
247-
}
248-
}
267+
return PaginatedCollection(
268+
client=self.client,
269+
query=ModelSlice.query_str(),
270+
params={'id': str(self.uid)},
271+
dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'],
272+
obj_class=lambda _, data_row_id_and_gk: data_row_id_and_gk.get('id'
273+
),
274+
cursor_path=[
275+
'getDataRowIdentifiersBySavedModelQuery', 'pageInfo',
276+
'endCursor'
277+
])
278+
279+
def get_data_row_identifiers(self) -> PaginatedCollection:
280+
"""
281+
Fetches all data row ids and global keys (where defined) that match this Slice
282+
283+
Returns:
284+
A PaginatedCollection of Slice.DataRowIdAndGlobalKey
249285
"""
250286
return PaginatedCollection(
251287
client=self.client,
252-
query=query_str,
253-
params={'id': self.uid},
254-
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
255-
obj_class=lambda _, data_row_id: data_row_id,
256-
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])
288+
query=ModelSlice.query_str(),
289+
params={'id': str(self.uid)},
290+
dereferencing=['getDataRowIdentifiersBySavedModelQuery', 'nodes'],
291+
obj_class=lambda _, data_row_id_and_gk: Slice.DataRowIdAndGlobalKey(
292+
data_row_id_and_gk.get('id'),
293+
data_row_id_and_gk.get('globalKey', None)),
294+
cursor_path=[
295+
'getDataRowIdentifiersBySavedModelQuery', 'pageInfo',
296+
'endCursor'
297+
])

0 commit comments

Comments
 (0)