Skip to content

Commit 254aafa

Browse files
committed
Add Slice object and function to export data row IDs by slice
1 parent 226944a commit 254aafa

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@
2727
from labelbox.schema.resource_tag import ResourceTag
2828
from labelbox.schema.project_resource_tag import ProjectResourceTag
2929
from labelbox.schema.media_type import MediaType
30+
from labelbox.schema.slice import Slice

labelbox/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from labelbox.schema.user import User
3434
from labelbox.schema.project import Project
3535
from labelbox.schema.role import Role
36+
from labelbox.schema.slice import Slice
3637

3738
from labelbox.schema.media_type import MediaType
3839

@@ -963,7 +964,7 @@ def assign_global_keys_to_data_rows(
963964
timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]:
964965
"""
965966
Assigns global keys to data rows.
966-
967+
967968
Args:
968969
A list of dicts containing data_row_id and global_key.
969970
Returns:
@@ -1211,3 +1212,27 @@ def _format_failed_rows(rows: List[str],
12111212
"Timed out waiting for get_data_rows_for_global_keys job to complete."
12121213
)
12131214
time.sleep(sleep_time)
1215+
1216+
def get_slice(self, slice_id) -> Slice:
1217+
"""
1218+
Fetches a Slice by ID.
1219+
1220+
Args:
1221+
slice_id (str): The ID of the Slice
1222+
Returns:
1223+
Slice
1224+
"""
1225+
query_str = """
1226+
query getSavedQueryPyApi($id: ID!) {
1227+
getSavedQuery(id: $id) {
1228+
id
1229+
name
1230+
description
1231+
filter
1232+
createdAt
1233+
updatedAt
1234+
}
1235+
}
1236+
"""
1237+
res = self.execute(query_str, {'id': slice_id})
1238+
return Entity.Slice(self, res['getSavedQuery'])

labelbox/schema/slice.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from labelbox.orm.db_object import DbObject
2+
from labelbox.orm.model import Field
3+
4+
5+
class Slice(DbObject):
6+
name = Field.String("name")
7+
description = Field.String("description")
8+
created_at = Field.DateTime("created_at")
9+
updated_at = Field.DateTime("updated_at")
10+
filter = Field.Json("filter")
11+
12+
def get_data_row_ids(self) -> list[str]:
13+
"""
14+
Fetches all data row ids that match this Slice
15+
16+
Returns:
17+
A list of data row ids
18+
"""
19+
query_str = """
20+
query getDataRowIdsBySavedQueryPyApi($id: ID!, $after: String) {
21+
getDataRowIdsBySavedQuery(input: {
22+
savedQueryId: $id,
23+
after: $after
24+
}) {
25+
totalCount
26+
nodes
27+
pageInfo {
28+
endCursor
29+
hasNextPage
30+
}
31+
}
32+
}
33+
"""
34+
data_row_ids = []
35+
total_count = 0
36+
end_cursor = None
37+
has_next_page = True
38+
while has_next_page:
39+
res = self.client.execute(query_str, {
40+
'id': self.uid,
41+
'after': end_cursor
42+
})['getDataRowIdsBySavedQuery']
43+
data_row_ids = data_row_ids + res['nodes']
44+
total_count = res['totalCount']
45+
has_next_page = res['pageInfo']['hasNextPage']
46+
end_cursor = res['pageInfo']['endCursor']
47+
48+
assert total_count == len(data_row_ids)
49+
return data_row_ids

0 commit comments

Comments
 (0)