Skip to content

Commit 8dc2466

Browse files
author
Matt Sokoloff
committed
Merge branch 'develop' of https://github.com/Labelbox/labelbox-python into ms/al-3814
2 parents 6b9f0e3 + c7b0623 commit 8dc2466

File tree

10 files changed

+119
-15
lines changed

10 files changed

+119
-15
lines changed

docs/source/index.rst

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,14 @@ DataRowMetadata
173173

174174
.. automodule:: labelbox.schema.data_row_metadata
175175
:members:
176-
:show-inheritance:
176+
:show-inheritance:
177177

178178
AnnotationImport
179179
----------------------------
180180

181181
.. automodule:: labelbox.schema.annotation_import
182182
:members:
183-
:show-inheritance:
183+
:show-inheritance:
184184

185185
Batch
186186
----------------------------
@@ -194,4 +194,19 @@ ResourceTag
194194

195195
.. automodule:: labelbox.schema.resource_tag
196196
:members:
197-
:show-inheritance:
197+
:show-inheritance:
198+
199+
Slice
200+
-----------------------------------------
201+
202+
.. automodule:: labelbox.schema.slice
203+
:members: Slice
204+
:exclude-members: CatalogSlice
205+
:show-inheritance:
206+
207+
CatalogSlice
208+
-----------------------------------------
209+
.. automodule:: labelbox.schema.slice
210+
:members: CatalogSlice
211+
:exclude-members: Slice
212+
:show-inheritance:

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@
2727
from labelbox.schema.resource_tag import ResourceTag
2828
from labelbox.schema.project_resource_tag import ProjectResourceTag
2929
from labelbox.schema.media_type import MediaType
30+
from labelbox.schema.slice import Slice, CatalogSlice

labelbox/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from labelbox.schema.user import User
3434
from labelbox.schema.project import Project
3535
from labelbox.schema.role import Role
36+
from labelbox.schema.slice import CatalogSlice
3637

3738
from labelbox.schema.media_type import MediaType
3839

@@ -963,7 +964,7 @@ def assign_global_keys_to_data_rows(
963964
timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]:
964965
"""
965966
Assigns global keys to data rows.
966-
967+
967968
Args:
968969
A list of dicts containing data_row_id and global_key.
969970
Returns:
@@ -1211,3 +1212,27 @@ def _format_failed_rows(rows: List[str],
12111212
"Timed out waiting for get_data_rows_for_global_keys job to complete."
12121213
)
12131214
time.sleep(sleep_time)
1215+
1216+
def get_catalog_slice(self, slice_id) -> CatalogSlice:
1217+
"""
1218+
Fetches a Catalog Slice by ID.
1219+
1220+
Args:
1221+
slice_id (str): The ID of the Slice
1222+
Returns:
1223+
CatalogSlice
1224+
"""
1225+
query_str = """
1226+
query getSavedQueryPyApi($id: ID!) {
1227+
getSavedQuery(id: $id) {
1228+
id
1229+
name
1230+
description
1231+
filter
1232+
createdAt
1233+
updatedAt
1234+
}
1235+
}
1236+
"""
1237+
res = self.execute(query_str, {'id': slice_id})
1238+
return Entity.CatalogSlice(self, res['getSavedQuery'])

labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ class Entity(metaclass=EntityMeta):
377377
ProjectRole: Type[labelbox.ProjectRole]
378378
Project: Type[labelbox.Project]
379379
Batch: Type[labelbox.Batch]
380+
CatalogSlice: Type[labelbox.CatalogSlice]
380381

381382
@classmethod
382383
def _attributes_of_type(cls, attr_type):

labelbox/schema/slice.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from labelbox.orm.db_object import DbObject
2+
from labelbox.orm.model import Field
3+
from labelbox.pagination import PaginatedCollection
4+
5+
6+
class Slice(DbObject):
7+
"""
8+
A Slice is a saved set of filters (saved query).
9+
This is an abstract class and should not be instantiated.
10+
11+
Attributes:
12+
name (str)
13+
description (str)
14+
created_at (datetime)
15+
updated_at (datetime)
16+
filter (json)
17+
"""
18+
name = Field.String("name")
19+
description = Field.String("description")
20+
created_at = Field.DateTime("created_at")
21+
updated_at = Field.DateTime("updated_at")
22+
filter = Field.Json("filter")
23+
24+
25+
class CatalogSlice(Slice):
26+
"""
27+
Represents a Slice used for filtering data rows in Catalog.
28+
"""
29+
30+
def get_data_row_ids(self) -> PaginatedCollection:
31+
"""
32+
Fetches all data row ids that match this Slice
33+
34+
Returns:
35+
A PaginatedCollection of data row ids
36+
"""
37+
query_str = """
38+
query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
39+
getDataRowIdsBySavedQuery(input: {
40+
savedQueryId: $id,
41+
after: $from
42+
first: $first
43+
}) {
44+
totalCount
45+
nodes
46+
pageInfo {
47+
endCursor
48+
hasNextPage
49+
}
50+
}
51+
}
52+
"""
53+
return PaginatedCollection(
54+
client=self.client,
55+
query=query_str,
56+
params={'id': self.uid},
57+
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
58+
obj_class=lambda _, data_row_id: data_row_id,
59+
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])

tests/integration/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from labelbox.orm import query
1616
from labelbox.pagination import PaginatedCollection
1717
from labelbox.schema.annotation_import import LabelImport
18+
from labelbox.schema.enums import AnnotationImportState
1819
from labelbox.schema.invite import Invite
1920
from labelbox.schema.queue_mode import QueueMode
2021
from labelbox.schema.user import User
@@ -334,6 +335,7 @@ def create_label():
334335
upload_task = LabelImport.create_from_objects(
335336
client, project.uid, f'label-import-{uuid.uuid4()}', predictions)
336337
upload_task.wait_until_done(sleep_time_seconds=5)
338+
assert upload_task.state == AnnotationImportState.FINISHED
337339

338340
project.create_label = create_label
339341
project.create_label()

tests/integration/test_data_rows.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,14 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url):
151151
@pytest.mark.slow
152152
def test_data_row_large_bulk_creation(dataset, image_url):
153153
# Do a longer task and expect it not to be complete immediately
154-
n_local = 2000
155-
n_urls = 250
154+
n_urls = 1000
155+
n_local = 250
156156
with NamedTemporaryFile() as fp:
157157
fp.write("Test data".encode())
158158
fp.flush()
159159
task = dataset.create_data_rows([{
160160
DataRow.row_data: image_url
161-
}] * n_local + [fp.name] * n_urls)
161+
}] * n_urls + [fp.name] * n_local)
162162
task.wait_till_done()
163163
assert task.status == "COMPLETE"
164164
assert len(list(dataset.data_rows())) == n_local + n_urls
@@ -353,7 +353,7 @@ def test_create_data_rows_with_invalid_metadata(dataset, image_url):
353353
DataRow.metadata_fields: fields
354354
}])
355355
task.wait_till_done()
356-
assert task.status == "COMPLETE"
356+
assert task.status == "FAILED"
357357
assert len(task.failed_data_rows) > 0
358358

359359

@@ -634,9 +634,10 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image):
634634
}])
635635

636636
task.wait_till_done()
637-
assert task.status == "COMPLETE"
637+
assert task.status == "FAILED"
638638
assert len(task.failed_data_rows) > 0
639639
assert len(list(dataset.data_rows())) == 0
640+
assert task.errors == "Import job failed"
640641

641642
task = dataset.create_data_rows([{
642643
DataRow.row_data: sample_image,

tests/integration/test_delegated_access.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from labelbox import Client
77

88

9-
@pytest.mark.skipif(os.environ.get("DA_GCP_LABELBOX_API_KEY") is None,
9+
@pytest.mark.skipif(not os.environ.get('DA_GCP_LABELBOX_API_KEY'),
1010
reason="DA_GCP_LABELBOX_API_KEY not found")
1111
def test_default_integration():
1212
"""
@@ -28,7 +28,7 @@ def test_default_integration():
2828
ds.delete()
2929

3030

31-
@pytest.mark.skipif(os.environ.get("DA_GCP_LABELBOX_API_KEY") is None,
31+
@pytest.mark.skipif(not os.environ.get("DA_GCP_LABELBOX_API_KEY"),
3232
reason="DA_GCP_LABELBOX_API_KEY not found")
3333
def test_non_default_integration():
3434
"""

tests/integration/test_label.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def test_label_export(configured_project_with_label):
4646
# TODO: Skipping this test in staging due to label not updating
4747
@pytest.mark.skipif(condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem" or
4848
os.environ['LABELBOX_TEST_ENVIRON'] == "staging" or
49+
os.environ['LABELBOX_TEST_ENVIRON'] == "local" or
4950
os.environ['LABELBOX_TEST_ENVIRON'] == "custom",
5051
reason="does not work for onprem")
5152
def test_label_update(configured_project_with_label):

tests/integration/test_task.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from labelbox import DataRow, Task
3+
from labelbox import DataRow
44
from labelbox.schema.data_row_metadata import DataRowMetadataField
55

66
EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
@@ -21,13 +21,12 @@ def test_task_errors(dataset, image_url):
2121
]
2222
},
2323
])
24+
2425
assert task in client.get_user().created_tasks()
2526
task.wait_till_done()
26-
assert task.status == "COMPLETE"
27+
assert task.status == "FAILED"
2728
assert len(task.failed_data_rows) > 0
2829
assert task.errors is not None
29-
assert 'message' in task.errors[0]
30-
assert len(task.result) == 0
3130

3231

3332
def test_task_success_json(dataset, image_url):

0 commit comments

Comments
 (0)