1+ import logging
12from concurrent .futures import ThreadPoolExecutor , as_completed
23from typing import Callable , Generator , Iterable , Union
34from uuid import uuid4
45
56from labelbox .data .annotation_types .label import Label
7+ from labelbox .data .generator import PrefetchGenerator
68from labelbox .orm .model import Entity
79from labelbox .schema .ontology import OntologyBuilder
810from tqdm import tqdm
911
12+ logger = logging .getLogger (__name__ )
13+
1014
1115class LabelCollection :
1216 """
1317 A container for
1418
1519 """
20+
1621 def __init__ (self , data : Iterable [Label ]):
1722 self ._data = data
1823 self ._index = 0
@@ -35,7 +40,8 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> Label:
    """
    Return the label stored at position `idx`.

    The lookup is delegated directly to the underlying `self._data`
    container, so indexing semantics (negative indices, errors for
    out-of-range positions, ...) are whatever that container provides.
    """
    return self._data[idx]
3742
38- def assign_schema_ids (self , ontology_builder : OntologyBuilder ) -> "LabelCollection" :
43+ def assign_schema_ids (
44+ self , ontology_builder : OntologyBuilder ) -> "LabelCollection" :
3945 """
4046 Based on an ontology:
4147 - Checks to make sure that the feature names exist in the ontology
@@ -57,19 +63,22 @@ def _ensure_unique_external_ids(self) -> None:
5763 )
5864 external_ids .add (label .data .external_id )
5965
60- def add_to_dataset (self , dataset , signer , max_concurrency = 20 ) -> "LabelCollection" :
66+ def add_to_dataset (self ,
67+ dataset ,
68+ signer ,
69+ max_concurrency = 20 ) -> "LabelCollection" :
6170 """
6271 # It is recommended to create a new dataset if memory is a concern
6372 # Also note that this relies on exported data that is cached.
6473 # So this will not work on the same dataset more frequently than every 30 min.
6574 # The workaround is creating a new dataset
6675 """
6776 self ._ensure_unique_external_ids ()
68- self .add_urls_to_data (signer ,
69- max_concurrency = max_concurrency )
70- upload_task = dataset . create_data_rows (
71- [{ Entity .DataRow .row_data : label . data . url , Entity . DataRow . external_id : label .data .external_id } for label in self . _data ]
72- )
77+ self .add_urls_to_data (signer , max_concurrency = max_concurrency )
78+ upload_task = dataset . create_data_rows ([{
79+ Entity . DataRow . row_data : label . data . url ,
80+ Entity .DataRow .external_id : label .data .external_id
81+ } for label in self . _data ] )
7382 upload_task .wait_til_done ()
7483
7584 data_row_lookup = {
@@ -80,20 +89,26 @@ def add_to_dataset(self, dataset, signer, max_concurrency=20) -> "LabelCollectio
8089 label .data .uid = data_row_lookup [label .data .external_id ]
8190 return self
8291
def add_urls_to_masks(self, signer, max_concurrency=20) -> "LabelCollection":
    """
    Call `add_url_to_masks(signer)` on every label in the collection.

    The per-label calls are dispatched through `self._apply_threaded`.

    Args:
        signer: Passed unchanged as the only argument to each label's
            `add_url_to_masks`.
        max_concurrency: Forwarded to `_apply_threaded` (presumably the
            size of the worker pool -- confirm against that helper).

    Returns:
        This collection, so that calls can be chained.

    TODO: Add error handling.
    """
    tasks = [label.add_url_to_masks for label in self._data]
    # `_apply_threaded` is a generator: it has to be iterated to the end so
    # that the result of every task is collected before returning.
    for _ in self._apply_threaded(tasks, max_concurrency, signer):
        pass
    return self
91104
def add_urls_to_data(self, signer, max_concurrency=20) -> "LabelCollection":
    """
    Call `add_url_to_data(signer)` on every label in the collection.

    The per-label calls are dispatched through `self._apply_threaded`.

    Args:
        signer: Passed unchanged as the only argument to each label's
            `add_url_to_data`.
        max_concurrency: Forwarded to `_apply_threaded` (presumably the
            size of the worker pool -- confirm against that helper).

    Returns:
        This collection, so that calls can be chained.

    TODO: Add error handling.
    """
    tasks = [label.add_url_to_data for label in self._data]
    # `_apply_threaded` is a generator: it has to be iterated to the end so
    # that the result of every task is collected before returning.
    for _ in self._apply_threaded(tasks, max_concurrency, signer):
        pass
    return self
99114
@@ -105,68 +120,84 @@ def _apply_threaded(self, fns, max_concurrency, *args):
105120 for future in tqdm (as_completed (futures )):
106121 yield future .result ()
107122
class LabelGenerator(PrefetchGenerator):
    """
    Use this class if you have larger data. It is slightly harder to work with
    than the LabelCollection but will be much more memory efficient.

    Transformations are not applied eagerly. Each of the `assign_*` / `add_*`
    methods stores a function in `self._fns`, and `process` runs every stored
    function on a label when it is produced.
    """

    def __init__(self, data: Generator[Label, None, None], *args, **kwargs):
        """
        Args:
            data: Source of labels, handed to `PrefetchGenerator`.
            *args, **kwargs: Forwarded unchanged to `PrefetchGenerator`.
        """
        # Set up the registry before the base initializer runs.
        # NOTE(review): presumably the base class may start calling
        # `process` right away -- confirm in PrefetchGenerator.
        self._fns = {}
        super().__init__(data, *args, **kwargs)

    def __iter__(self):
        return self

    def process(self, value):
        """Apply every registered function to `value` and return the result."""
        for fn in self._fns.values():
            value = fn(value)
        return value

    def as_collection(self) -> "LabelCollection":
        """Drain this generator into an in-memory `LabelCollection`."""
        # Iterating `self` (rather than the raw data) goes through
        # `__next__`, so the registered functions are applied to every label.
        return LabelCollection(data=list(self))

    def assign_schema_ids(
            self, ontology_builder: OntologyBuilder) -> "LabelGenerator":
        """
        Register a step that calls `label.assign_schema_ids(ontology_builder)`
        on each label. Returns this generator so calls can be chained.
        """

        def _assign_ids(label: Label):
            label.assign_schema_ids(ontology_builder)
            return label

        self._fns['assign_schema_ids'] = _assign_ids
        return self

    def add_urls_to_data(self, signer: Callable[[bytes],
                                                str]) -> "LabelGenerator":
        """
        Register a step that calls `label.add_url_to_data(signer)` on each
        label. Returns this generator so calls can be chained.
        """

        def _add_urls_to_data(label: Label):
            label.add_url_to_data(signer)
            return label

        self._fns['_add_urls_to_data'] = _add_urls_to_data
        return self

    def add_to_dataset(self, dataset,
                       signer: Callable[[bytes], str]) -> "LabelGenerator":
        """
        Register a step that calls `label.create_data_row(dataset, signer)`
        on each label. Returns this generator so calls can be chained.
        """

        def _add_to_dataset(label: Label):
            label.create_data_row(dataset, signer)
            return label

        self._fns['assign_datarow_ids'] = _add_to_dataset
        return self

    def add_urls_to_masks(self, signer: Callable[[bytes],
                                                 str]) -> "LabelGenerator":
        """
        Updates masks to have `url` attribute
        Doesn't update masks that already have urls

        Registers a step that calls `label.add_url_to_masks(signer)` on each
        label. Returns this generator so calls can be chained.
        """

        def _add_urls_to_masks(label: Label):
            label.add_url_to_masks(signer)
            return label

        self._fns['add_urls_to_masks'] = _add_urls_to_masks
        return self

    def __next__(self):
        """
        - Double check that all values have been set.
        - Items could have been processed before any of these modifying
          functions were called.
        - None of these functions do anything if run more than once, so the
          cost is minimal.
        """
        value = super().__next__()
        return self.process(value)
170201
171202
# Either of the two label containers defined in this module.
LabelData = Union[LabelCollection, LabelGenerator]
0 commit comments