Skip to content

Commit 92925e9

Browse files
author
Matt Sokoloff
committed
label containers improved
1 parent cfb8487 commit 92925e9

File tree

16 files changed

+351
-99
lines changed

16 files changed

+351
-99
lines changed

labelbox/data/annotation_types/annotation.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
from typing import List, Union, Dict, Any
1+
from typing import Any, Dict, List, Union
22

3-
from labelbox.data.annotation_types.classification.classification import Dropdown, Text, CheckList, Radio
4-
from labelbox.data.annotation_types.reference import FeatureSchemaRef
5-
from labelbox.data.annotation_types.ner import TextEntity
3+
from labelbox.data.annotation_types.classification.classification import (
4+
CheckList, Dropdown, Radio, Text)
65
from labelbox.data.annotation_types.geometry import Geometry
6+
from labelbox.data.annotation_types.ner import TextEntity
7+
from labelbox.data.annotation_types.reference import FeatureSchemaRef
78

89

910
class BaseAnnotation(FeatureSchemaRef):
Lines changed: 135 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,172 @@
11
from concurrent.futures import ThreadPoolExecutor, as_completed
2-
from typing import Iterable, List, Any
2+
from typing import Callable, Generator, Iterable, Union
33
from uuid import uuid4
44

5-
from pydantic import BaseModel
6-
75
from labelbox.data.annotation_types.label import Label
86
from labelbox.orm.model import Entity
7+
from labelbox.schema.ontology import OntologyBuilder
8+
from tqdm import tqdm
9+
10+
11+
class LabelCollection:
12+
"""
13+
A container for an in-memory iterable of Labels, supporting indexing, iteration, and bulk upload operations.
14+
15+
"""
16+
def __init__(self, data: Iterable[Label]):
17+
self._data = data
18+
self._index = 0
19+
20+
def __iter__(self):
21+
self._index = 0
22+
return self
23+
24+
def __next__(self) -> Label:
25+
if self._index == len(self._data):
26+
raise StopIteration
27+
28+
value = self._data[self._index]
29+
self._index += 1
30+
return value
931

32+
def __len__(self) -> int:
33+
return len(self._data)
1034

11-
class LabelCollection(BaseModel):
12-
data: Iterable[Label]
35+
def __getitem__(self, idx: int) -> Label:
36+
return self._data[idx]
1337

14-
def assign_schema_ids(self, ontology_builder):
38+
def assign_schema_ids(self, ontology_builder: OntologyBuilder) -> "LabelCollection":
1539
"""
1640
Based on an ontology:
1741
- Checks to make sure that the feature names exist in the ontology
1842
- Updates the names to match the ontology.
1943
"""
20-
for label in self.data:
21-
for annotation in label.annotations:
22-
annotation.assign_schema_ids(ontology_builder)
44+
for label in self._data:
45+
label.assign_schema_ids(ontology_builder)
46+
return self
2347

24-
def create_dataset(self, client, dataset_name, signer, max_concurrency=20):
48+
def _ensure_unique_external_ids(self) -> None:
2549
external_ids = set()
26-
for label in self.data:
50+
for label in self._data:
2751
if label.data.external_id is None:
2852
label.data.external_id = uuid4()
2953
else:
3054
if label.data.external_id in external_ids:
3155
raise ValueError(
32-
f"External ids must be unique for bulk uploading. Found {label.data.exeternal_id} more than once."
56+
f"External ids must be unique for bulk uploading. Found {label.data.external_id} more than once."
3357
)
3458
external_ids.add(label.data.external_id)
35-
labels = self.create_urls_for_data(signer,
59+
60+
def add_to_dataset(self, dataset, signer, max_concurrency=20) -> "LabelCollection":
61+
"""
62+
# It is recommended to create a new dataset if memory is a concern
63+
# Also note that this relies on exported data that is cached.
64+
# So this will not work on the same dataset more frequently than every 30 min.
65+
# The workaround is creating a new dataset
66+
"""
67+
self._ensure_unique_external_ids()
68+
self.add_urls_to_data(signer,
3669
max_concurrency=max_concurrency)
37-
dataset = client.create_dataset(name=dataset_name)
38-
upload_task = dataset.create_data_row(
39-
{Entity.DataRow.row_data: label.data.url for label in labels})
70+
upload_task = dataset.create_data_rows(
71+
[{Entity.DataRow.row_data: label.data.url, Entity.DataRow.external_id: label.data.external_id} for label in self._data]
72+
)
4073
upload_task.wait_til_done()
4174

42-
data_rows = {
75+
data_row_lookup = {
4376
data_row.external_id: data_row.uid
4477
for data_row in dataset.export_data_rows()
4578
}
46-
for label in self.data:
47-
data_row = data_rows[label.data.external_id]
48-
label.data.uid = data_row.uid
79+
for label in self._data:
80+
label.data.uid = data_row_lookup[label.data.external_id]
81+
return self
4982

50-
def create_urls_for_masks(self, signer, max_concurrency=20):
83+
def add_urls_to_masks(self, signer, max_concurrency=20) -> "LabelCollection":
5184
"""
5285
Creates a data row id for each data row that needs it. If the data row exists then it skips the row.
5386
TODO: Add error handling..
5487
"""
55-
futures = {}
56-
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
57-
for label in self.data:
58-
futures[executor.submit(label.create_url_for_masks)] = label
59-
for future in as_completed(futures):
60-
# Yields the label. But this function modifies the objects to have updated urls.
61-
yield futures[future]
62-
del futures[future]
63-
64-
def create_urls_for_data(self, signer, max_concurrency=20):
88+
for row in self._apply_threaded([label.add_url_to_masks for label in self._data], max_concurrency, signer):
89+
...
90+
return self
91+
92+
def add_urls_to_data(self, signer, max_concurrency=20) -> "LabelCollection":
6593
"""
6694
TODO: Add error handling..
6795
"""
68-
futures = {}
96+
for row in self._apply_threaded([label.add_url_to_data for label in self._data], max_concurrency, signer):
97+
...
98+
return self
99+
100+
def _apply_threaded(self, fns, max_concurrency, *args):
101+
futures = []
69102
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
70-
for label in self.data:
71-
futures[executor.submit(label.create_url_for_data)] = label
72-
for future in as_completed(futures):
73-
yield futures[future]
74-
del futures[future]
103+
for fn in fns:
104+
futures.append(executor.submit(fn, *args))
105+
for future in tqdm(as_completed(futures)):
106+
yield future.result()
107+
108+
class LabelGenerator:
109+
"""
110+
Use this class if you have larger data. It is slightly harder to work with
111+
than the LabelCollection but will be much more memory efficient.
112+
"""
113+
def __init__(self, data: Generator[Label, None,None]):
114+
if isinstance(data, (list, tuple)):
115+
self._data = (r for r in data)
116+
else:
117+
self._data = data
118+
self._fns = {}
119+
120+
def __iter__(self):
121+
return self
122+
123+
def __next__(self) -> Label:
124+
# Maybe some sort of prefetching could be nice
125+
# to make things faster if users are applying io functions
126+
value = next(self._data)
127+
for fn in self._fns.values():
128+
value = fn(value)
129+
return value
130+
131+
def as_collection(self) -> "LabelCollection":
132+
return LabelCollection(data = list(self._data))
133+
134+
def assign_schema_ids(self, ontology_builder: OntologyBuilder) -> "LabelGenerator":
135+
def _assign_ids(label: Label):
136+
label.assign_schema_ids(ontology_builder)
137+
return label
138+
self._fns['assign_schema_ids'] = _assign_ids
139+
return self
140+
141+
def add_urls_to_data(self, signer: Callable[[bytes], str]) -> "LabelGenerator":
142+
"""
143+
Updates data to have `url` attribute
144+
Doesn't update data that already has urls
145+
"""
146+
def _add_urls_to_data(label: Label):
147+
label.add_url_to_data(signer)
148+
return label
149+
self._fns['_add_urls_to_data'] = _add_urls_to_data
150+
return self
151+
152+
def add_to_dataset(self, dataset, signer: Callable[[bytes], str]) -> "LabelGenerator":
153+
def _add_to_dataset(label: Label):
154+
label.create_data_row(dataset, signer)
155+
return label
156+
self._fns['assign_datarow_ids'] = _add_to_dataset
157+
return self
158+
159+
def add_urls_to_masks(self, signer: Callable[[bytes], str]) -> "LabelGenerator":
160+
"""
161+
Updates masks to have `url` attribute
162+
Doesn't update masks that already have urls
163+
"""
164+
def _add_urls_to_masks(label: Label):
165+
label.add_url_to_masks(signer)
166+
return label
167+
self._fns['add_urls_to_masks'] = _add_urls_to_masks
168+
return self
169+
170+
171+
172+
LabelData = Union[LabelCollection, LabelGenerator]

labelbox/data/annotation_types/data/raster.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
from typing import Callable, Dict, Any, Optional
21
from io import BytesIO
2+
from typing import Any, Callable, Dict, Optional
33

4-
from PIL import Image
54
import numpy as np
65
import requests
7-
from pydantic import ValidationError, root_validator
8-
96
from labelbox.data.annotation_types.reference import DataRowRef
7+
from PIL import Image
8+
from pydantic import ValidationError, root_validator
109

1110

1211
class RasterData(DataRowRef):
@@ -72,16 +71,16 @@ def validate_args(cls, values):
7271
arr = values.get("arr")
7372
uid = values.get('uid')
7473
if uid == file_path == im_bytes == url == None and arr is None:
75-
raise ValidationError(
74+
raise ValueError(
7675
"One of `file_path`, `im_bytes`, `url`, `uid` or `arr` required."
7776
)
7877
if arr is not None:
7978
if arr.dtype != np.uint8:
80-
raise ValidationError(
79+
raise TypeError(
8180
"Numpy array representing segmentation mask must be np.uint8"
8281
)
8382
elif len(arr.shape) not in [2, 3]:
84-
raise ValidationError(
83+
raise TypeError(
8584
f"Numpy array must have 2 or 3 dims. Found shape {arr.shape}"
8685
)
8786
return values

labelbox/data/annotation_types/geometry/mask.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from typing import Any, Dict, Tuple
22

33
import numpy as np
4+
from labelbox.data.annotation_types.data.raster import RasterData
5+
from labelbox.data.annotation_types.geometry.geometry import Geometry
46
from rasterio.features import shapes
57
from shapely.geometry import MultiPolygon, shape
68

7-
from labelbox.data.annotation_types.geometry.geometry import Geometry
8-
from labelbox.data.annotation_types.data.raster import RasterData
9-
109

1110
class Mask(Geometry):
1211
# Raster data can be shared across multiple masks... or not

labelbox/data/annotation_types/label.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,33 @@
1-
from typing import Union, List, Dict, Any
1+
from typing import Any, Dict, List, Union
22

3-
from pydantic import BaseModel
4-
5-
from labelbox.schema.ontology import Classification as OClassification, Option
6-
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer
7-
from labelbox.data.annotation_types.annotation import AnnotationType, ClassificationAnnotation, ObjectAnnotation, VideoAnnotationType
3+
from labelbox.data.annotation_types.annotation import (
4+
AnnotationType, ClassificationAnnotation, ObjectAnnotation,
5+
VideoAnnotationType)
6+
from labelbox.data.annotation_types.classification.classification import \
7+
ClassificationAnswer
88
from labelbox.data.annotation_types.data.raster import RasterData
99
from labelbox.data.annotation_types.data.text import TextData
1010
from labelbox.data.annotation_types.data.video import VideoData
11-
from labelbox.data.annotation_types.metrics import Metric
1211
from labelbox.data.annotation_types.geometry.mask import Mask
12+
from labelbox.data.annotation_types.metrics import Metric
13+
from labelbox.schema.ontology import Classification as OClassification
14+
from labelbox.schema.ontology import Option
15+
from pydantic import BaseModel
1316

1417

1518
class Label(BaseModel):
1619
data: Union[VideoData, RasterData, TextData]
1720
annotations: List[Union[AnnotationType, VideoAnnotationType, Metric]] = []
1821
extra: Dict[str, Any] = {}
1922

20-
def create_url_for_data(self, signer):
21-
return self.data.create_url(signer)
23+
def add_url_to_data(self, signer):
24+
"""
25+
Only creates a url if one doesn't exist
26+
"""
27+
self.data.create_url(signer)
28+
return self
2229

23-
def create_url_for_masks(self, signer):
30+
def add_url_to_masks(self, signer):
2431
masks = []
2532
for annotation in self.annotations:
2633
# Allows us to upload shared masks once
@@ -29,12 +36,20 @@ def create_url_for_masks(self, signer):
2936
masks.append(annotation.value.mask)
3037
for mask in masks:
3138
mask.create_url(signer)
39+
return self
3240

3341
def create_data_row(self, dataset, signer):
34-
data_row = dataset.create_data_row(
35-
row_data=self.create_url_for_data(signer))
42+
args = {
43+
'row_data' : self.add_url_to_data(signer)
44+
}
45+
if self.data.external_id is not None:
46+
args.update({
47+
'external'
48+
})
49+
data_row = dataset.create_data_row(**args)
3650
self.data.uid = data_row.uid
37-
return data_row
51+
self.data.external_id = data_row.external_id
52+
return self
3853

3954
def get_feature_schema_lookup(self, ontology_builder):
4055
tool_lookup = {}
@@ -64,8 +79,6 @@ def flatten_classification(classifications):
6479
def assign_schema_ids(self, ontology_builder):
6580
"""
6681
Classifications get flattened when labeling.
67-
68-
6982
"""
7083

7184
def assign_or_raise(annotation, lookup):
@@ -106,3 +119,4 @@ def assign_option(classification, lookup):
106119
else:
107120
raise TypeError(
108121
f"Unexpected type found for annotation. {type(annotation)}")
122+
return self

0 commit comments

Comments
 (0)