Skip to content

Commit 38e1d0c

Browse files
author
Kevin Kim
committed
Merge branch 'develop' of https://github.com/Labelbox/labelbox-python into kkim/polish-tests
2 parents 58ac47b + 70ee77d commit 38e1d0c

File tree

10 files changed

+140
-15
lines changed

10 files changed

+140
-15
lines changed

docs/source/index.rst

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,14 @@ DataRowMetadata
173173

174174
.. automodule:: labelbox.schema.data_row_metadata
175175
:members:
176-
:show-inheritance:
176+
:show-inheritance:
177177

178178
AnnotationImport
179179
----------------------------
180180

181181
.. automodule:: labelbox.schema.annotation_import
182182
:members:
183-
:show-inheritance:
183+
:show-inheritance:
184184

185185
Batch
186186
----------------------------
@@ -194,4 +194,19 @@ ResourceTag
194194

195195
.. automodule:: labelbox.schema.resource_tag
196196
:members:
197-
:show-inheritance:
197+
:show-inheritance:
198+
199+
Slice
200+
-----------------------------------------
201+
202+
.. automodule:: labelbox.schema.slice
203+
:members: Slice
204+
:exclude-members: CatalogSlice
205+
:show-inheritance:
206+
207+
CatalogSlice
208+
-----------------------------------------
209+
.. automodule:: labelbox.schema.slice
210+
:members: CatalogSlice
211+
:exclude-members: Slice
212+
:show-inheritance:

labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@
2727
from labelbox.schema.resource_tag import ResourceTag
2828
from labelbox.schema.project_resource_tag import ProjectResourceTag
2929
from labelbox.schema.media_type import MediaType
30+
from labelbox.schema.slice import Slice, CatalogSlice

labelbox/client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from labelbox.schema.user import User
3434
from labelbox.schema.project import Project
3535
from labelbox.schema.role import Role
36+
from labelbox.schema.slice import CatalogSlice
3637

3738
from labelbox.schema.media_type import MediaType
3839

@@ -963,7 +964,7 @@ def assign_global_keys_to_data_rows(
963964
timeout_seconds=60) -> Dict[str, Union[str, List[Any]]]:
964965
"""
965966
Assigns global keys to data rows.
966-
967+
967968
Args:
968969
A list of dicts containing data_row_id and global_key.
969970
Returns:
@@ -1211,3 +1212,27 @@ def _format_failed_rows(rows: List[str],
12111212
"Timed out waiting for get_data_rows_for_global_keys job to complete."
12121213
)
12131214
time.sleep(sleep_time)
1215+
1216+
def get_catalog_slice(self, slice_id) -> CatalogSlice:
1217+
"""
1218+
Fetches a Catalog Slice by ID.
1219+
1220+
Args:
1221+
slice_id (str): The ID of the Slice
1222+
Returns:
1223+
CatalogSlice
1224+
"""
1225+
query_str = """
1226+
query getSavedQueryPyApi($id: ID!) {
1227+
getSavedQuery(id: $id) {
1228+
id
1229+
name
1230+
description
1231+
filter
1232+
createdAt
1233+
updatedAt
1234+
}
1235+
}
1236+
"""
1237+
res = self.execute(query_str, {'id': slice_id})
1238+
return Entity.CatalogSlice(self, res['getSavedQuery'])

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ def geometry(self) -> Dict[str, Tuple[int, int, int]]:
5454
external_polygons = self._extract_polygons_from_contours(
5555
external_contours)
5656
holes = self._extract_polygons_from_contours(holes)
57+
58+
if not external_polygons.is_valid:
59+
external_polygons = external_polygons.buffer(0)
60+
61+
if not holes.is_valid:
62+
holes = holes.buffer(0)
63+
5764
return external_polygons.difference(holes).__geo_interface__
5865

5966
def draw(self,
@@ -78,7 +85,6 @@ def draw(self,
7885
np.ndarray representing only this object
7986
as opposed to the mask that this object references which might have multiple objects determined by colors
8087
"""
81-
8288
mask = self.mask.value
8389
mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8)
8490

labelbox/data/serialization/coco/instance_dataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ def mask_to_coco_object_annotation(
2121
# This is going to fill any holes into the multipolygon
2222
# If you need to support holes use the panoptic data format
2323
shapely = annotation.value.shapely.simplify(1).buffer(0)
24-
2524
if shapely.is_empty:
2625
return
2726

labelbox/data/serialization/coco/panoptic_dataset.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ def vector_to_coco_segment_info(canvas: np.ndarray,
2222
annotation: ObjectAnnotation,
2323
annotation_idx: int, image: CocoImage,
2424
category_id: int):
25+
2526
shapely = annotation.value.shapely
27+
if shapely.is_empty:
28+
return
29+
2630
xmin, ymin, xmax, ymax = shapely.bounds
2731
canvas = annotation.value.draw(height=image.height,
2832
width=image.width,
@@ -40,6 +44,9 @@ def mask_to_coco_segment_info(canvas: np.ndarray, annotation,
4044
color = id_to_rgb(annotation_idx)
4145
mask = annotation.value.draw(color=color)
4246
shapely = annotation.value.shapely
47+
if shapely.is_empty:
48+
return
49+
4350
xmin, ymin, xmax, ymax = shapely.bounds
4451
canvas = np.where(canvas == (0, 0, 0), mask, canvas)
4552
return SegmentInfo(id=annotation_idx,
@@ -70,20 +77,32 @@ def process_label(label: Label,
7077
for annotation_idx, annotation in enumerate(annotations[class_name]):
7178
categories[annotation.name] = hash_category_name(annotation.name)
7279
if isinstance(annotation.value, Mask):
73-
segment, canvas = (mask_to_coco_segment_info(
80+
coco_segment_info = mask_to_coco_segment_info(
7481
canvas, annotation, class_idx + 1,
75-
categories[annotation.name]))
82+
categories[annotation.name])
83+
84+
if coco_segment_info is None:
85+
# Filter out empty masks
86+
continue
87+
88+
segment, canvas = coco_segment_info
7689
segments.append(segment)
7790
is_thing[annotation.name] = 0
7891

7992
elif isinstance(annotation.value, (Polygon, Rectangle)):
80-
segment, canvas = vector_to_coco_segment_info(
93+
coco_vector_info = vector_to_coco_segment_info(
8194
canvas,
8295
annotation,
8396
annotation_idx=(class_idx if all_stuff else annotation_idx)
8497
+ 1,
8598
image=image,
8699
category_id=categories[annotation.name])
100+
101+
if coco_vector_info is None:
102+
# Filter out empty annotations
103+
continue
104+
105+
segment, canvas = coco_vector_info
87106
segments.append(segment)
88107
is_thing[annotation.name] = 1 - int(all_stuff)
89108

labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ class Entity(metaclass=EntityMeta):
377377
ProjectRole: Type[labelbox.ProjectRole]
378378
Project: Type[labelbox.Project]
379379
Batch: Type[labelbox.Batch]
380+
CatalogSlice: Type[labelbox.CatalogSlice]
380381

381382
@classmethod
382383
def _attributes_of_type(cls, attr_type):

labelbox/schema/slice.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from labelbox.orm.db_object import DbObject
2+
from labelbox.orm.model import Field
3+
from labelbox.pagination import PaginatedCollection
4+
5+
6+
class Slice(DbObject):
7+
"""
8+
A Slice is a saved set of filters (saved query).
9+
This is an abstract class and should not be instantiated.
10+
11+
Attributes:
12+
name (str)
13+
description (str)
14+
created_at (datetime)
15+
updated_at (datetime)
16+
filter (json)
17+
"""
18+
name = Field.String("name")
19+
description = Field.String("description")
20+
created_at = Field.DateTime("created_at")
21+
updated_at = Field.DateTime("updated_at")
22+
filter = Field.Json("filter")
23+
24+
25+
class CatalogSlice(Slice):
26+
"""
27+
Represents a Slice used for filtering data rows in Catalog.
28+
"""
29+
30+
def get_data_row_ids(self) -> PaginatedCollection:
31+
"""
32+
Fetches all data row ids that match this Slice
33+
34+
Returns:
35+
A PaginatedCollection of data row ids
36+
"""
37+
query_str = """
38+
query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
39+
getDataRowIdsBySavedQuery(input: {
40+
savedQueryId: $id,
41+
after: $from
42+
first: $first
43+
}) {
44+
totalCount
45+
nodes
46+
pageInfo {
47+
endCursor
48+
hasNextPage
49+
}
50+
}
51+
}
52+
"""
53+
return PaginatedCollection(
54+
client=self.client,
55+
query=query_str,
56+
params={'id': self.uid},
57+
dereferencing=['getDataRowIdsBySavedQuery', 'nodes'],
58+
obj_class=lambda _, data_row_id: data_row_id,
59+
cursor_path=['getDataRowIdsBySavedQuery', 'pageInfo', 'endCursor'])

tests/integration/test_data_rows.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def test_create_data_rows_with_invalid_metadata(dataset, image_url):
352352
DataRow.metadata_fields: fields
353353
}])
354354
task.wait_till_done()
355-
assert task.status == "COMPLETE"
355+
assert task.status == "FAILED"
356356
assert len(task.failed_data_rows) > 0
357357

358358

@@ -633,9 +633,10 @@ def test_data_row_bulk_creation_with_same_global_keys(dataset, sample_image):
633633
}])
634634

635635
task.wait_till_done()
636-
assert task.status == "COMPLETE"
636+
assert task.status == "FAILED"
637637
assert len(task.failed_data_rows) > 0
638638
assert len(list(dataset.data_rows())) == 0
639+
assert task.errors == "Import job failed"
639640

640641
task = dataset.create_data_rows([{
641642
DataRow.row_data: sample_image,

tests/integration/test_task.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from labelbox import DataRow, Task
3+
from labelbox import DataRow
44
from labelbox.schema.data_row_metadata import DataRowMetadataField
55

66
EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
@@ -21,13 +21,12 @@ def test_task_errors(dataset, image_url):
2121
]
2222
},
2323
])
24+
2425
assert task in client.get_user().created_tasks()
2526
task.wait_till_done()
26-
assert task.status == "COMPLETE"
27+
assert task.status == "FAILED"
2728
assert len(task.failed_data_rows) > 0
2829
assert task.errors is not None
29-
assert 'message' in task.errors[0]
30-
assert len(task.result) == 0
3130

3231

3332
def test_task_success_json(dataset, image_url):

0 commit comments

Comments
 (0)