1 | 1 | from concurrent.futures import ThreadPoolExecutor, as_completed |
2 | | -from typing import Iterable, List, Any |
| 2 | +from typing import Callable, Generator, Iterable, Union |
3 | 3 | from uuid import uuid4 |
4 | 4 |
5 | | -from pydantic import BaseModel |
6 | | - |
7 | 5 | from labelbox.data.annotation_types.label import Label |
8 | 6 | from labelbox.orm.model import Entity |
| 7 | +from labelbox.schema.ontology import OntologyBuilder |
| 8 | +from tqdm import tqdm |
| 9 | + |
| 10 | + |
| 11 | +class LabelCollection: |
| 12 | + """ |
| 13 | + A container for in-memory labels. |
| 14 | + Supports indexing, iteration, and bulk operations such as uploading to a dataset. |
| 15 | + """ |
| 16 | + def __init__(self, data: Iterable[Label]): |
| 17 | + self._data = data |
| 18 | + self._index = 0 |
| 19 | + |
| 20 | + def __iter__(self): |
| 21 | + self._index = 0 |
| 22 | + return self |
| 23 | + |
| 24 | + def __next__(self) -> Label: |
| 25 | + if self._index == len(self._data): |
| 26 | + raise StopIteration |
| 27 | + |
| 28 | + value = self._data[self._index] |
| 29 | + self._index += 1 |
| 30 | + return value |
9 | 31 |
| 32 | + def __len__(self) -> int: |
| 33 | + return len(self._data) |
10 | 34 |
11 | | -class LabelCollection(BaseModel): |
12 | | - data: Iterable[Label] |
| 35 | + def __getitem__(self, idx: int) -> Label: |
| 36 | + return self._data[idx] |
13 | 37 |
14 | | - def assign_schema_ids(self, ontology_builder): |
| 38 | + def assign_schema_ids(self, ontology_builder: OntologyBuilder) -> "LabelCollection": |
15 | 39 | """ |
16 | 40 | Based on an ontology: |
17 | 41 | - Checks that the feature names exist in the ontology |
18 | 42 | - Assigns the matching schema ids to each annotation. |
19 | 43 | """ |
20 | | - for label in self.data: |
21 | | - for annotation in label.annotations: |
22 | | - annotation.assign_schema_ids(ontology_builder) |
| 44 | + for label in self._data: |
| 45 | + label.assign_schema_ids(ontology_builder) |
| 46 | + return self |
23 | 47 |
24 | | - def create_dataset(self, client, dataset_name, signer, max_concurrency=20): |
| 48 | + def _ensure_unique_external_ids(self) -> None: |
25 | 49 | external_ids = set() |
26 | | - for label in self.data: |
| 50 | + for label in self._data: |
27 | 51 | if label.data.external_id is None: |
28 | 52 | label.data.external_id = str(uuid4()) |
29 | 53 | else: |
30 | 54 | if label.data.external_id in external_ids: |
31 | 55 | raise ValueError( |
32 | | - f"External ids must be unique for bulk uploading. Found {label.data.exeternal_id} more than once." |
| 56 | + f"External ids must be unique for bulk uploading. Found {label.data.external_id} more than once." |
33 | 57 | ) |
34 | 58 | external_ids.add(label.data.external_id) |
35 | | - labels = self.create_urls_for_data(signer, |
| 59 | + |
| 60 | + def add_to_dataset(self, dataset, signer, max_concurrency=20) -> "LabelCollection": |
| 61 | + """ |
| 62 | + It is recommended to create a new dataset if memory is a concern. |
| 63 | + Also note that this relies on exported data that is cached, |
| 64 | + so it will not work on the same dataset more frequently than every 30 min. |
| 65 | + The workaround is to create a new dataset. |
| 66 | + """ |
| 67 | + self._ensure_unique_external_ids() |
| 68 | + self.add_urls_to_data(signer, |
36 | 69 | max_concurrency=max_concurrency) |
37 | | - dataset = client.create_dataset(name=dataset_name) |
38 | | - upload_task = dataset.create_data_row( |
39 | | - {Entity.DataRow.row_data: label.data.url for label in labels}) |
| 70 | + upload_task = dataset.create_data_rows( |
| 71 | + [{Entity.DataRow.row_data: label.data.url, Entity.DataRow.external_id: label.data.external_id} for label in self._data] |
| 72 | + ) |
40 | 73 | upload_task.wait_till_done() |
41 | 74 |
42 | | - data_rows = { |
| 75 | + data_row_lookup = { |
43 | 76 | data_row.external_id: data_row.uid |
44 | 77 | for data_row in dataset.export_data_rows() |
45 | 78 | } |
46 | | - for label in self.data: |
47 | | - data_row = data_rows[label.data.external_id] |
48 | | - label.data.uid = data_row.uid |
| 79 | + for label in self._data: |
| 80 | + label.data.uid = data_row_lookup[label.data.external_id] |
| 81 | + return self |
49 | 82 |
50 | | - def create_urls_for_masks(self, signer, max_concurrency=20): |
| 83 | + def add_urls_to_masks(self, signer, max_concurrency=20) -> "LabelCollection": |
51 | 84 | """ |
52 | 85 | Creates a signed url for each mask. Masks that already have a url are skipped. |
53 | 86 | TODO: Add error handling.. |
54 | 87 | """ |
55 | | - futures = {} |
56 | | - with ThreadPoolExecutor(max_workers=max_concurrency) as executor: |
57 | | - for label in self.data: |
58 | | - futures[executor.submit(label.create_url_for_masks)] = label |
59 | | - for future in as_completed(futures): |
60 | | - # Yields the label. But this function modifies the objects to have updated urls. |
61 | | - yield futures[future] |
62 | | - del futures[future] |
63 | | - |
64 | | - def create_urls_for_data(self, signer, max_concurrency=20): |
| 88 | + for row in self._apply_threaded([label.add_url_to_masks for label in self._data], max_concurrency, signer): |
| 89 | + ... |
| 90 | + return self |
| 91 | + |
| 92 | + def add_urls_to_data(self, signer, max_concurrency=20) -> "LabelCollection": |
65 | 93 | """ |
66 | 94 | TODO: Add error handling.. |
67 | 95 | """ |
68 | | - futures = {} |
| 96 | + for row in self._apply_threaded([label.add_url_to_data for label in self._data], max_concurrency, signer): |
| 97 | + ... |
| 98 | + return self |
| 99 | + |
| 100 | + def _apply_threaded(self, fns, max_concurrency, *args): |
| 101 | + futures = [] |
69 | 102 | with ThreadPoolExecutor(max_workers=max_concurrency) as executor: |
70 | | - for label in self.data: |
71 | | - futures[executor.submit(label.create_url_for_data)] = label |
72 | | - for future in as_completed(futures): |
73 | | - yield futures[future] |
74 | | - del futures[future] |
| 103 | + for fn in fns: |
| 104 | + futures.append(executor.submit(fn, *args)) |
| 105 | + for future in tqdm(as_completed(futures)): |
| 106 | + yield future.result() |
| 107 | + |
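A rough usage sketch of the eager `LabelCollection` path, to show how the pieces above fit together. `my_labels`, the API key/project id, and the `Client` / `OntologyBuilder.from_project` / `upload_data` entry points are assumptions for illustration, not part of this diff:

```python
from labelbox import Client
from labelbox.schema.ontology import OntologyBuilder

client = Client(api_key="...")                 # assumed credentials
project = client.get_project("<project_id>")   # project that owns the ontology

def signer(content: bytes) -> str:
    # Assumption: any callable that stores the bytes somewhere Labelbox can
    # reach and returns a URL will do; client.upload_data is one option.
    return client.upload_data(content)

ontology_builder = OntologyBuilder.from_project(project)
dataset = client.create_dataset(name="bulk-upload")   # fresh dataset, per the docstring caveat

labels = LabelCollection(data=my_labels)       # my_labels: a list of Label objects
labels.assign_schema_ids(ontology_builder) \
      .add_urls_to_masks(signer) \
      .add_to_dataset(dataset, signer)
```

A fresh dataset is used here on purpose, matching the note in `add_to_dataset` about cached exports.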
| 108 | +class LabelGenerator: |
| 109 | + """ |
| 110 | + Use this class for large datasets. It is slightly harder to work with |
| 111 | + than LabelCollection but is much more memory efficient. |
| 112 | + """ |
| 113 | + def __init__(self, data: Generator[Label, None, None]): |
| 114 | + if isinstance(data, (list, tuple)): |
| 115 | + self._data = (r for r in data) |
| 116 | + else: |
| 117 | + self._data = data |
| 118 | + self._fns = {} |
| 119 | + |
| 120 | + def __iter__(self): |
| 121 | + return self |
| 122 | + |
| 123 | + def __next__(self) -> Label: |
| 124 | + # Some sort of prefetching could make this faster |
| 125 | + # if users are applying IO-bound functions |
| 126 | + value = next(self._data) |
| 127 | + for fn in self._fns.values(): |
| 128 | + value = fn(value) |
| 129 | + return value |
| 130 | + |
| 131 | + def as_collection(self) -> "LabelCollection": |
| 132 | + return LabelCollection(data=list(self._data)) |
| 133 | + |
| 134 | + def assign_schema_ids(self, ontology_builder: OntologyBuilder) -> "LabelGenerator": |
| 135 | + def _assign_ids(label: Label): |
| 136 | + label.assign_schema_ids(ontology_builder) |
| 137 | + return label |
| 138 | + self._fns['assign_schema_ids'] = _assign_ids |
| 139 | + return self |
| 140 | + |
| 141 | + def add_urls_to_data(self, signer: Callable[[bytes], str]) -> "LabelGenerator": |
| 142 | + """ |
| 143 | + Updates data to have a `url` attribute |
| 144 | + Doesn't update data that already has a url |
| 145 | + """ |
| 146 | + def _add_urls_to_data(label: Label): |
| 147 | + label.add_url_to_data(signer) |
| 148 | + return label |
| 149 | + self._fns['add_urls_to_data'] = _add_urls_to_data |
| 150 | + return self |
| 151 | + |
| 152 | + def add_to_dataset(self, dataset, signer: Callable[[bytes], str]) -> "LabelGenerator": |
| 153 | + def _add_to_dataset(label: Label): |
| 154 | + label.create_data_row(dataset, signer) |
| 155 | + return label |
| 156 | + self._fns['assign_datarow_ids'] = _add_to_dataset |
| 157 | + return self |
| 158 | + |
| 159 | + def add_urls_to_masks(self, signer: Callable[[bytes], str]) -> "LabelGenerator": |
| 160 | + """ |
| 161 | + Updates masks to have a `url` attribute |
| 162 | + Doesn't update masks that already have urls |
| 163 | + """ |
| 164 | + def _add_urls_to_masks(label: Label): |
| 165 | + label.add_url_to_masks(signer) |
| 166 | + return label |
| 167 | + self._fns['add_urls_to_masks'] = _add_urls_to_masks |
| 168 | + return self |
| 169 | + |
| 170 | + |
| 171 | + |
| 172 | +LabelData = Union[LabelCollection, LabelGenerator] |
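A minimal sketch of the lazy `LabelGenerator` path, reusing `client`, `signer`, and `ontology_builder` from the sketch above; `parse_labels` is a hypothetical loader. The queued functions only run as each label is consumed:

```python
def parse_labels():
    # Hypothetical loader: yield Label objects one at a time, e.g. parsed from a
    # large export file, so the whole dataset never sits in memory at once.
    yield from []  # a real loader would yield Label instances here

dataset = client.create_dataset(name="bulk-upload-generator")  # fresh dataset again

labels: LabelData = (LabelGenerator(data=parse_labels())
                     .assign_schema_ids(ontology_builder)
                     .add_urls_to_masks(signer)
                     .add_to_dataset(dataset, signer))

# Nothing has been uploaded yet; each queued function is applied per label
# as the generator is iterated.
for label in labels:
    pass  # label now has schema ids, mask urls, and a backing data row
```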