Skip to content

Commit 723d926

Browse files
author
Matt Sokoloff
committed
prefetch generator
1 parent 92925e9 commit 723d926

File tree

15 files changed

+231
-113
lines changed

15 files changed

+231
-113
lines changed
Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
from typing import Any, Dict, List, Union, ForwardRef
2-
from pydantic.class_validators import validator
3-
4-
from pydantic.main import BaseModel
1+
from typing import Any, Dict, List
52

63
from labelbox.data.annotation_types.reference import FeatureSchemaRef
4+
from pydantic.main import BaseModel
75

86

97
class ClassificationAnswer(FeatureSchemaRef):
@@ -16,7 +14,6 @@ class Radio(BaseModel):
1614

1715
class CheckList(BaseModel):
1816
answer: List[ClassificationAnswer]
19-
# TODO: Validate that there is only one of each answer
2017

2118

2219
class Text(BaseModel):
@@ -25,4 +22,3 @@ class Text(BaseModel):
2522

2623
class Dropdown(BaseModel):
2724
answer: List[ClassificationAnswer]
28-
# TODO: Validate that there is only one of each answer

labelbox/data/annotation_types/collection.py

Lines changed: 56 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
1+
import logging
12
from concurrent.futures import ThreadPoolExecutor, as_completed
23
from typing import Callable, Generator, Iterable, Union
34
from uuid import uuid4
45

56
from labelbox.data.annotation_types.label import Label
7+
from labelbox.data.generator import PrefetchGenerator
68
from labelbox.orm.model import Entity
79
from labelbox.schema.ontology import OntologyBuilder
810
from tqdm import tqdm
911

12+
logger = logging.getLogger(__name__)
13+
1014

1115
class LabelCollection:
1216
"""
1317
A container for
1418
1519
"""
20+
1621
def __init__(self, data: Iterable[Label]):
1722
self._data = data
1823
self._index = 0
@@ -35,7 +40,8 @@ def __len__(self) -> int:
3540
def __getitem__(self, idx: int) -> Label:
3641
return self._data[idx]
3742

38-
def assign_schema_ids(self, ontology_builder: OntologyBuilder) -> "LabelCollection":
43+
def assign_schema_ids(
44+
self, ontology_builder: OntologyBuilder) -> "LabelCollection":
3945
"""
4046
Based on an ontology:
4147
- Checks to make sure that the feature names exist in the ontology
@@ -57,19 +63,22 @@ def _ensure_unique_external_ids(self) -> None:
5763
)
5864
external_ids.add(label.data.external_id)
5965

60-
def add_to_dataset(self, dataset, signer, max_concurrency=20) -> "LabelCollection":
66+
def add_to_dataset(self,
67+
dataset,
68+
signer,
69+
max_concurrency=20) -> "LabelCollection":
6170
"""
6271
# It is recommended to create a new dataset if memory is a concern
6372
# Also note that this relies on exported data that is cached.
6473
# So this will not work on the same dataset more frequently than every 30 min.
6574
# The workaround is creating a new dataset
6675
"""
6776
self._ensure_unique_external_ids()
68-
self.add_urls_to_data(signer,
69-
max_concurrency=max_concurrency)
70-
upload_task = dataset.create_data_rows(
71-
[{Entity.DataRow.row_data: label.data.url, Entity.DataRow.external_id: label.data.external_id} for label in self._data]
72-
)
77+
self.add_urls_to_data(signer, max_concurrency=max_concurrency)
78+
upload_task = dataset.create_data_rows([{
79+
Entity.DataRow.row_data: label.data.url,
80+
Entity.DataRow.external_id: label.data.external_id
81+
} for label in self._data])
7382
upload_task.wait_til_done()
7483

7584
data_row_lookup = {
@@ -80,20 +89,26 @@ def add_to_dataset(self, dataset, signer, max_concurrency=20) -> "LabelCollectio
8089
label.data.uid = data_row_lookup[label.data.external_id]
8190
return self
8291

83-
def add_urls_to_masks(self, signer, max_concurrency=20) -> "LabelCollection":
92+
def add_urls_to_masks(self,
93+
signer,
94+
max_concurrency=20) -> "LabelCollection":
8495
"""
8596
Creates a data row id for each data row that needs it. If the data row exists then it skips the row.
8697
TODO: Add error handling..
8798
"""
88-
for row in self._apply_threaded([label.add_url_to_masks for label in self._data], max_concurrency, signer):
99+
for row in self._apply_threaded(
100+
[label.add_url_to_masks for label in self._data], max_concurrency,
101+
signer):
89102
...
90103
return self
91104

92105
def add_urls_to_data(self, signer, max_concurrency=20) -> "LabelCollection":
93106
"""
94107
TODO: Add error handling..
95108
"""
96-
for row in self._apply_threaded([label.add_url_to_data for label in self._data], max_concurrency, signer):
109+
for row in self._apply_threaded(
110+
[label.add_url_to_data for label in self._data], max_concurrency,
111+
signer):
97112
...
98113
return self
99114

@@ -105,68 +120,84 @@ def _apply_threaded(self, fns, max_concurrency, *args):
105120
for future in tqdm(as_completed(futures)):
106121
yield future.result()
107122

108-
class LabelGenerator:
123+
124+
class LabelGenerator(PrefetchGenerator):
109125
"""
110126
Use this class if you have larger data. It is slightly harder to work with
111127
than the LabelCollection but will be much more memory efficient.
112128
"""
113-
def __init__(self, data: Generator[Label, None,None]):
114-
if isinstance(data, (list, tuple)):
115-
self._data = (r for r in data)
116-
else:
117-
self._data = data
129+
130+
def __init__(self, data: Generator[Label, None, None], *args, **kwargs):
118131
self._fns = {}
132+
super().__init__(data, *args, **kwargs)
119133

120134
def __iter__(self):
121135
return self
122136

123-
def __next__(self) -> Label:
124-
# Maybe some sort of prefetching could be nice
125-
# to make things faster if users are applying io functions
126-
value = next(self._data)
137+
def process(self, value):
127138
for fn in self._fns.values():
128139
value = fn(value)
129140
return value
130141

131142
def as_collection(self) -> "LabelCollection":
132-
return LabelCollection(data = list(self._data))
143+
return LabelCollection(data=list(self))
144+
145+
def assign_schema_ids(
146+
self, ontology_builder: OntologyBuilder) -> "LabelGenerator":
133147

134-
def assign_schema_ids(self, ontology_builder: OntologyBuilder) -> "LabelGenerator":
135148
def _assign_ids(label: Label):
136149
label.assign_schema_ids(ontology_builder)
137150
return label
151+
138152
self._fns['assign_schema_ids'] = _assign_ids
139153
return self
140154

141-
def add_urls_to_data(self, signer: Callable[[bytes], str]) -> "LabelGenerator":
155+
def add_urls_to_data(self, signer: Callable[[bytes],
156+
str]) -> "LabelGenerator":
142157
"""
143158
Updates masks to have `url` attribute
144159
Doesn't update masks that already have urls
145160
"""
161+
146162
def _add_urls_to_data(label: Label):
147163
label.add_url_to_data(signer)
148164
return label
165+
149166
self._fns['_add_urls_to_data'] = _add_urls_to_data
150167
return self
151168

152-
def add_to_dataset(self, dataset, signer: Callable[[bytes], str]) -> "LabelGenerator":
169+
def add_to_dataset(self, dataset,
170+
signer: Callable[[bytes], str]) -> "LabelGenerator":
171+
153172
def _add_to_dataset(label: Label):
154173
label.create_data_row(dataset, signer)
155174
return label
175+
156176
self._fns['assign_datarow_ids'] = _add_to_dataset
157177
return self
158178

159-
def add_urls_to_masks(self, signer: Callable[[bytes], str]) -> "LabelGenerator":
179+
def add_urls_to_masks(self, signer: Callable[[bytes],
180+
str]) -> "LabelGenerator":
160181
"""
161182
Updates masks to have `url` attribute
162183
Doesn't update masks that already have urls
163184
"""
185+
164186
def _add_urls_to_masks(label: Label):
165187
label.add_url_to_masks(signer)
166188
return label
189+
167190
self._fns['add_urls_to_masks'] = _add_urls_to_masks
168191
return self
169192

193+
def __next__(self):
194+
"""
195+
- Double check that all values have been set.
196+
- Items could have been processed before any of these modifying functions are called.
197+
- None of these functions do anything if run more than once so the cost is minimal.
198+
"""
199+
value = super().__next__()
200+
return self.process(value)
170201

171202

172203
LabelData = Union[LabelCollection, LabelGenerator]

labelbox/data/annotation_types/data/raster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def data(self) -> np.ndarray:
4949
raise ValueError("Must set either url, file_path or im_bytes")
5050

5151
def create_url(self, signer: Callable[[bytes], str]) -> None:
52+
5253
if self.url is not None:
5354
return self.url
5455
elif self.im_bytes is not None:

labelbox/data/annotation_types/data/text.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from typing import Callable, Optional
22

33
import requests
4-
from pydantic import ValidationError, root_validator
5-
64
from labelbox.data.annotation_types.reference import DataRowRef
5+
from pydantic import ValidationError, root_validator
76

87

98
class TextData(DataRowRef):
@@ -49,7 +48,7 @@ def validate_date(cls, values):
4948
url = values.get("url")
5049
uid = values.get('uid')
5150
if uid == file_path == text == url == None:
52-
raise ValidationError(
51+
raise ValueError(
5352
"One of `file_path`, `text`, `uid`, or `url` required.")
5453
return values
5554

labelbox/data/annotation_types/data/video.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import logging
2-
from uuid import uuid4
32
import os
4-
from typing import Generator, Callable, Optional, Tuple, Dict, Any
3+
import urllib.request
4+
from typing import Any, Callable, Dict, Generator, Optional, Tuple
5+
from uuid import uuid4
56

67
import cv2
7-
import urllib.request
88
import numpy as np
9-
from pydantic import ValidationError, root_validator
10-
119
from labelbox.data.annotation_types.reference import DataRowRef
10+
from pydantic import ValidationError, root_validator
1211

1312
logger = logging.getLogger(__name__)
1413

@@ -104,7 +103,7 @@ def validate_data(cls, values):
104103
uid = values.get("uid")
105104

106105
if uid == file_path == frames == url == None:
107-
raise ValidationError(
106+
raise ValueError(
108107
"One of `file_path`, `frames`, `uid`, or `url` required.")
109108
return values
110109

labelbox/data/annotation_types/geometry/rectangle.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from typing import Dict, Any
1+
from typing import Any, Dict
22

3-
import numpy as np
43
import cv2
54
import geojson
6-
5+
import numpy as np
76
from labelbox.data.annotation_types.geometry.geometry import Geometry
87
from labelbox.data.annotation_types.geometry.point import Point
98

labelbox/data/annotation_types/label.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from typing import Any, Dict, List, Union
22

3-
from labelbox.data.annotation_types.annotation import (
4-
AnnotationType, ClassificationAnnotation, ObjectAnnotation,
5-
VideoAnnotationType)
3+
from labelbox.data.annotation_types.annotation import (AnnotationType,
4+
ClassificationAnnotation,
5+
ObjectAnnotation,
6+
VideoAnnotationType)
67
from labelbox.data.annotation_types.classification.classification import \
78
ClassificationAnswer
89
from labelbox.data.annotation_types.data.raster import RasterData
@@ -27,7 +28,7 @@ def add_url_to_data(self, signer):
2728
self.data.create_url(signer)
2829
return self
2930

30-
def add_url_to_masks(self, signer):
31+
def add_url_to_masks(self, signer) -> "Label":
3132
masks = []
3233
for annotation in self.annotations:
3334
# Allows us to upload shared masks once
@@ -39,16 +40,17 @@ def add_url_to_masks(self, signer):
3940
return self
4041

4142
def create_data_row(self, dataset, signer):
42-
args = {
43-
'row_data' : self.add_url_to_data(signer)
44-
}
43+
"""
44+
Only overwrites if necessary
45+
46+
"""
47+
args = {'row_data': self.add_url_to_data(signer)}
4548
if self.data.external_id is not None:
46-
args.update({
47-
'external'
48-
})
49-
data_row = dataset.create_data_row(**args)
50-
self.data.uid = data_row.uid
51-
self.data.external_id = data_row.external_id
49+
args.update({'external'})
50+
if self.data.uid is None:
51+
data_row = dataset.create_data_row(**args)
52+
self.data.uid = data_row.uid
53+
self.data.external_id = data_row.external_id
5254
return self
5355

5456
def get_feature_schema_lookup(self, ontology_builder):

0 commit comments

Comments
 (0)