Skip to content

Commit cd3caf2

Browse files
author
Matt Sokoloff
committed
update tests
1 parent 12b278d commit cd3caf2

File tree

10 files changed

+27
-30
lines changed

10 files changed

+27
-30
lines changed

labelbox/data/generator.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import threading
33
from queue import Queue
44
from typing import Any, Iterable
5+
from concurrent.futures import ThreadPoolExecutor
56

67
logger = logging.getLogger(__name__)
78

@@ -32,10 +33,7 @@ class PrefetchGenerator:
3233
Useful for modifying the generator results based on data from a network
3334
"""
3435

35-
def __init__(self,
36-
data: Iterable[Any],
37-
prefetch_limit=20,
38-
max_concurrency=4):
36+
def __init__(self, data: Iterable[Any], prefetch_limit=20, num_executors=4):
3937
if isinstance(data, (list, tuple)):
4038
self._data = (r for r in data)
4139
else:
@@ -44,14 +42,13 @@ def __init__(self,
4442
self.queue = Queue(prefetch_limit)
4543
self._data = ThreadSafeGen(self._data)
4644
self.completed_threads = 0
47-
self.max_concurrency = max_concurrency
48-
self.threads = [
49-
threading.Thread(target=self.fill_queue)
50-
for _ in range(max_concurrency)
51-
]
52-
for thread in self.threads:
53-
thread.daemon = True
54-
thread.start()
45+
# Can only iterate over once it the queue.get hangs forever.
46+
self.done = False
47+
self.num_executors = num_executors
48+
with ThreadPoolExecutor(max_workers=num_executors) as executor:
49+
self.futures = [
50+
executor.submit(self.fill_queue) for _ in range(num_executors)
51+
]
5552

5653
def _process(self, value) -> Any:
5754
raise NotImplementedError("Abstract method needs to be implemented")
@@ -73,10 +70,13 @@ def __iter__(self):
7370
return self
7471

7572
def __next__(self) -> Any:
73+
if self.done:
74+
raise StopIteration
7675
value = self.queue.get()
7776
while value is None:
7877
self.completed_threads += 1
79-
if self.completed_threads == self.max_concurrency:
78+
if self.completed_threads == self.num_executors:
79+
self.done = True
8080
raise StopIteration
8181
value = self.queue.get()
8282
return value

labelbox/data/serialization/ndjson/classification.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,6 @@ def from_common(
189189
raise TypeError(
190190
f"Unable to convert object to MAL format. `{type(annotation.value)}`"
191191
)
192-
if len(annotation.classifications):
193-
raise ValueError(
194-
"Nested classifications not supported by this format")
195-
196192
return classify_obj.from_common(annotation.value, annotation.schema_id,
197193
annotation.extra, data)
198194

labelbox/data/serialization/ndjson/converter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from typing import Any, Dict, Generator, Iterable
33

4-
from ...annotation_types.collection import LabelData, LabelGenerator
4+
from ...annotation_types.collection import LabelCollection, LabelGenerator
55
from .label import NDLabel
66

77
logger = logging.getLogger(__name__)
@@ -23,7 +23,8 @@ def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
2323
return data.to_common()
2424

2525
@staticmethod
26-
def serialize(labels: LabelData) -> Generator[Dict[str, Any], None, None]:
26+
def serialize(
27+
labels: LabelCollection) -> Generator[Dict[str, Any], None, None]:
2728
"""
2829
Converts a labelbox common object to the labelbox ndjson format (prediction import format)
2930
@@ -32,7 +33,7 @@ def serialize(labels: LabelData) -> Generator[Dict[str, Any], None, None]:
3233
We will continue to improve the error messages and add helper functions to deal with this.
3334
3435
Args:
35-
labels: Either a LabelCollection or a LabelGenerator
36+
labels: Either a LabelList or a LabelGenerator
3637
Returns:
3738
A generator for accessing the ndjson representation of the data
3839
"""

labelbox/data/serialization/ndjson/label.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydantic import BaseModel
77

88
from ...annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation, VideoClassificationAnnotation
9-
from ...annotation_types.collection import LabelData, LabelGenerator
9+
from ...annotation_types.collection import LabelCollection, LabelGenerator
1010
from ...annotation_types.data import RasterData, TextData, VideoData
1111
from ...annotation_types.label import Label
1212
from ...annotation_types.ner import TextEntity
@@ -25,7 +25,8 @@ def to_common(self) -> LabelGenerator:
2525
data=self._generate_annotations(grouped_annotations))
2626

2727
@classmethod
28-
def from_common(cls, data: LabelData) -> Generator["NDLabel", None, None]:
28+
def from_common(cls,
29+
data: LabelCollection) -> Generator["NDLabel", None, None]:
2930
for label in data:
3031
yield from cls._create_non_video_annotations(label)
3132
yield from cls._create_video_annotations(label)

labelbox/data/serialization/ndjson/objects.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class NDMask(NDBaseObject):
123123

124124
def to_common(self) -> Mask:
125125
return Mask(mask=RasterData(url=self.mask.instanceURI),
126-
color_rgb=self.mask.colorRGB)
126+
color=self.mask.colorRGB)
127127

128128
@classmethod
129129
def from_common(cls, mask: Mask,
@@ -134,8 +134,7 @@ def from_common(cls, mask: Mask,
134134
raise ValueError(
135135
"Mask does not have a url. Use `LabelGenerator.add_url_to_masks`, `LabelCollection.add_url_to_masks`, or `Label.add_url_to_masks`."
136136
)
137-
return cls(mask=_Mask(instanceURI=mask.mask.url,
138-
colorRGB=mask.color_rgb),
137+
return cls(mask=_Mask(instanceURI=mask.mask.url, colorRGB=mask.color),
139138
dataRow=DataRow(id=data.uid),
140139
schema_id=schema_id,
141140
uuid=extra.get('uuid'),

tests/data/serialization/ndjson/test_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ def test_classification():
77
with open('tests/data/assets/ndjson/classification_import.json',
88
'r') as file:
99
data = json.load(file)
10-
res = NDJsonConverter.deserialize(data).as_collection()
10+
res = NDJsonConverter.deserialize(data).as_list()
1111
res = list(NDJsonConverter.serialize(res))
1212
assert res == data

tests/data/serialization/ndjson/test_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_image():
2020
with open('tests/data/assets/ndjson/image_import.json', 'r') as file:
2121
data = json.load(file)
2222

23-
res = NDJsonConverter.deserialize(data).as_collection()
23+
res = NDJsonConverter.deserialize(data).as_list()
2424
res = list(NDJsonConverter.serialize(res))
2525
for r in res:
2626
r.pop('classifications', None)

tests/data/serialization/ndjson/test_nested.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
def test_nested():
77
with open('tests/data/assets/ndjson/nested_import.json', 'r') as file:
88
data = json.load(file)
9-
res = NDJsonConverter.deserialize(data).as_collection()
9+
res = NDJsonConverter.deserialize(data).as_list()
1010
res = list(NDJsonConverter.serialize(res))
1111
assert res == data

tests/data/serialization/ndjson/test_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
def test_text():
77
with open('tests/data/assets/ndjson/text_import.json', 'r') as file:
88
data = json.load(file)
9-
res = NDJsonConverter.deserialize(data).as_collection()
9+
res = NDJsonConverter.deserialize(data).as_list()
1010
res = list(NDJsonConverter.serialize(res))
1111
assert res == data

tests/data/serialization/ndjson/test_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ def test_video():
77
with open('tests/data/assets/ndjson/video_import.json', 'r') as file:
88
data = json.load(file)
99

10-
res = NDJsonConverter.deserialize(data).as_collection()
10+
res = NDJsonConverter.deserialize(data).as_list()
1111
res = list(NDJsonConverter.serialize(res))
1212
assert res == [data[2], data[0], data[1]]

0 commit comments

Comments
 (0)