1 | | -import json |
2 | | -import os |
3 | 1 | from concurrent.futures import ThreadPoolExecutor, as_completed |
4 | 2 | |
5 | | -from typing import Iterable, List |
| 3 | +from typing import List |
6 | 4 | |
7 | | -from labelbox.exceptions import InvalidQueryError |
8 | | -from labelbox.exceptions import InvalidAttributeError |
9 | | -from labelbox.exceptions import MalformedQueryException |
10 | | -from labelbox.orm.model import Entity |
11 | | -from labelbox.orm.model import Field |
12 | | -from labelbox.schema.embedding import EmbeddingVector |
13 | | -from labelbox.pydantic_compat import BaseModel |
14 | | -from labelbox.schema.internal.datarow_upload_constants import ( |
15 | | - MAX_DATAROW_PER_API_OPERATION, FILE_UPLOAD_THREAD_COUNT) |
| 5 | +from labelbox import pydantic_compat |
16 | 6 | from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem |
| 7 | +from labelbox.schema.internal.descriptor_file_creator import DescriptorFileCreator |
17 | 8 | |
18 | 9 | |
19 | | -class UploadManifest(BaseModel): |
| 10 | +class UploadManifest(pydantic_compat.BaseModel): |
20 | 11 | source: str |
21 | 12 | item_count: int |
22 | 13 | chunk_uris: List[str] |
23 | 14 | |
24 | 15 | |
25 | | -class DataRowUploader: |
| 16 | +SOURCE_SDK = "SDK" |
26 | 17 | |
27 | | - @staticmethod |
28 | | - def create_descriptor_file(client, |
29 | | - items, |
30 | | - max_attachments_per_data_row=None, |
31 | | - is_upsert=False): |
32 | | - """ |
33 | | - This function is shared by `Dataset.create_data_rows`, `Dataset.create_data_rows_sync` and `Dataset.update_data_rows`. |
34 | | - It is used to prepare the input file. The user defined input is validated, processed, and json stringified. |
35 | | - Finally the json data is uploaded to gcs and a uri is returned. This uri can be passed as a parameter to a mutation that uploads data rows |
36 | 18 | |
37 | | - Each element in `items` can be either a `str` or a `dict`. If |
38 | | - it is a `str`, then it is interpreted as a local file path. The file |
39 | | - is uploaded to Labelbox and a DataRow referencing it is created. |
| 19 | +def upload_in_chunks(client, specs: List[DataRowUpsertItem], |
| 20 | + file_upload_thread_count: int, |
| 21 | + max_chunk_size_bytes: int) -> UploadManifest: |
| 22 | + empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) |
40 | 23 | |
41 | | - If an item is a `dict`, then it could support one of the two following structures |
42 | | - 1. For static imagery, video, and text it should map `DataRow` field names to values. |
43 | | - At the minimum an `items` passed as a `dict` must contain a `row_data` key and value. |
44 | | - If the value for row_data is a local file path and the path exists, |
45 | | - then the local file will be uploaded to labelbox. |
| 24 | + if empty_specs: |
| 25 | + ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) |
| 26 | + raise ValueError(f"The following items have an empty payload: {ids}") |
46 | 27 | |
47 | | - 2. For tiled imagery the dict must match the import structure specified in the link below |
48 | | - https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import |
| 28 | + chunk_uris = DescriptorFileCreator(client).create( |
| 29 | + specs, max_chunk_size_bytes=max_chunk_size_bytes) |
49 | 30 | |
50 | | - >>> dataset.create_data_rows([ |
51 | | - >>> {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"}, |
52 | | - >>> {DataRow.row_data:"/path/to/file1.jpg"}, |
53 | | - >>> "path/to/file2.jpg", |
54 | | - >>> {DataRow.row_data: {"tileLayerUrl" : "http://", ...}} |
55 | | - >>> {DataRow.row_data: {"type" : ..., 'version' : ..., 'messages' : [...]}} |
56 | | - >>> ]) |
57 | | - |
58 | | - For an example showing how to upload tiled data_rows see the following notebook: |
59 | | - https://github.com/Labelbox/labelbox-python/blob/ms/develop/model_assisted_labeling/tiled_imagery_mal.ipynb |
60 | | - |
61 | | - Args: |
62 | | - items (iterable of (dict or str)): See above for details. |
63 | | - max_attachments_per_data_row (Optional[int]): Param used during attachment validation to determine |
64 | | - if the user has provided too many attachments. |
65 | | - |
66 | | - Returns: |
67 | | - uri (string): A reference to the uploaded json data. |
68 | | - |
69 | | - Raises: |
70 | | - InvalidQueryError: If the `items` parameter does not conform to |
71 | | - the specification above or if the server did not accept the |
72 | | - DataRow creation request (unknown reason). |
73 | | - InvalidAttributeError: If there are fields in `items` not valid for |
74 | | - a DataRow. |
75 | | - ValueError: When the upload parameters are invalid |
76 | | - """ |
77 | | - file_upload_thread_count = FILE_UPLOAD_THREAD_COUNT |
78 | | - DataRow = Entity.DataRow |
79 | | - AssetAttachment = Entity.AssetAttachment |
80 | | - |
81 | | - def upload_if_necessary(item): |
82 | | - if is_upsert and 'row_data' not in item: |
83 | | - # When upserting, row_data is not required |
84 | | - return item |
85 | | - row_data = item['row_data'] |
86 | | - if isinstance(row_data, str) and os.path.exists(row_data): |
87 | | - item_url = client.upload_file(row_data) |
88 | | - item['row_data'] = item_url |
89 | | - if 'external_id' not in item: |
90 | | - # Default `external_id` to local file name |
91 | | - item['external_id'] = row_data |
92 | | - return item |
93 | | - |
94 | | - def validate_attachments(item): |
95 | | - attachments = item.get('attachments') |
96 | | - if attachments: |
97 | | - if isinstance(attachments, list): |
98 | | - if max_attachments_per_data_row and len( |
99 | | - attachments) > max_attachments_per_data_row: |
100 | | - raise ValueError( |
101 | | - f"Max attachments number of supported attachments per data row is {max_attachments_per_data_row}." |
102 | | - f" Found {len(attachments)}. Condense multiple attachments into one with the HTML attachment type if necessary." |
103 | | - ) |
104 | | - for attachment in attachments: |
105 | | - AssetAttachment.validate_attachment_json(attachment) |
106 | | - else: |
107 | | - raise ValueError( |
108 | | - f"Attachments must be a list. Found {type(attachments)}" |
109 | | - ) |
110 | | - return attachments |
111 | | - |
112 | | - def validate_embeddings(item): |
113 | | - embeddings = item.get("embeddings") |
114 | | - if embeddings: |
115 | | - item["embeddings"] = [ |
116 | | - EmbeddingVector(**e).to_gql() for e in embeddings |
117 | | - ] |
118 | | - |
119 | | - def validate_conversational_data(conversational_data: list) -> None: |
120 | | - """ |
121 | | - Checks each conversational message for keys expected as per https://docs.labelbox.com/reference/text-conversational#sample-conversational-json |
122 | | - |
123 | | - Args: |
124 | | - conversational_data (list): list of dictionaries. |
125 | | - """ |
126 | | - |
127 | | - def check_message_keys(message): |
128 | | - accepted_message_keys = set([ |
129 | | - "messageId", "timestampUsec", "content", "user", "align", |
130 | | - "canLabel" |
131 | | - ]) |
132 | | - for key in message.keys(): |
133 | | - if not key in accepted_message_keys: |
134 | | - raise KeyError( |
135 | | - f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" |
136 | | - ) |
137 | | - |
138 | | - if conversational_data and not isinstance(conversational_data, |
139 | | - list): |
140 | | - raise ValueError( |
141 | | - f"conversationalData must be a list. Found {type(conversational_data)}" |
142 | | - ) |
143 | | - |
144 | | - [check_message_keys(message) for message in conversational_data] |
145 | | - |
146 | | - def parse_metadata_fields(item): |
147 | | - metadata_fields = item.get('metadata_fields') |
148 | | - if metadata_fields: |
149 | | - mdo = client.get_data_row_metadata_ontology() |
150 | | - item['metadata_fields'] = mdo.parse_upsert_metadata( |
151 | | - metadata_fields) |
152 | | - |
153 | | - def format_row(item): |
154 | | - # Formats user input into a consistent dict structure |
155 | | - if isinstance(item, dict): |
156 | | - # Convert fields to strings |
157 | | - item = { |
158 | | - key.name if isinstance(key, Field) else key: value |
159 | | - for key, value in item.items() |
160 | | - } |
161 | | - elif isinstance(item, str): |
162 | | - # The main advantage of using a string over a dict is that the user is specifying |
163 | | - # that the file should exist locally. |
164 | | - # That info is lost after this section so we should check for it here. |
165 | | - if not os.path.exists(item): |
166 | | - raise ValueError(f"Filepath {item} does not exist.") |
167 | | - item = {"row_data": item, "external_id": item} |
168 | | - return item |
169 | | - |
170 | | - def validate_keys(item): |
171 | | - if not is_upsert and 'row_data' not in item: |
172 | | - raise InvalidQueryError( |
173 | | - "`row_data` missing when creating DataRow.") |
174 | | - |
175 | | - if isinstance(item.get('row_data'), |
176 | | - str) and item.get('row_data').startswith("s3:/"): |
177 | | - raise InvalidQueryError( |
178 | | - "row_data: s3 assets must start with 'https'.") |
179 | | - allowed_extra_fields = { |
180 | | - 'attachments', 'media_type', 'dataset_id', 'embeddings' |
181 | | - } |
182 | | - invalid_keys = set(item) - {f.name for f in DataRow.fields() |
183 | | - } - allowed_extra_fields |
184 | | - if invalid_keys: |
185 | | - raise InvalidAttributeError(DataRow, invalid_keys) |
186 | | - return item |
187 | | - |
188 | | - def format_legacy_conversational_data(item): |
189 | | - messages = item.pop("conversationalData") |
190 | | - version = item.pop("version", 1) |
191 | | - type = item.pop("type", "application/vnd.labelbox.conversational") |
192 | | - if "externalId" in item: |
193 | | - external_id = item.pop("externalId") |
194 | | - item["external_id"] = external_id |
195 | | - if "globalKey" in item: |
196 | | - global_key = item.pop("globalKey") |
197 | | - item["globalKey"] = global_key |
198 | | - validate_conversational_data(messages) |
199 | | - one_conversation = \ |
200 | | - { |
201 | | - "type": type, |
202 | | - "version": version, |
203 | | - "messages": messages |
204 | | - } |
205 | | - item["row_data"] = one_conversation |
206 | | - return item |
207 | | - |
208 | | - def convert_item(data_row_item): |
209 | | - if isinstance(data_row_item, DataRowUpsertItem): |
210 | | - item = data_row_item.payload |
211 | | - else: |
212 | | - item = data_row_item |
213 | | - |
214 | | - if "tileLayerUrl" in item: |
215 | | - validate_attachments(item) |
216 | | - return item |
217 | | - |
218 | | - if "conversationalData" in item: |
219 | | - format_legacy_conversational_data(item) |
220 | | - |
221 | | - # Convert all payload variations into the same dict format |
222 | | - item = format_row(item) |
223 | | - # Make sure required keys exist (and there are no extra keys) |
224 | | - validate_keys(item) |
225 | | - # Make sure attachments are valid |
226 | | - validate_attachments(item) |
227 | | - # Make sure embeddings are valid |
228 | | - validate_embeddings(item) |
229 | | - # Parse metadata fields if they exist |
230 | | - parse_metadata_fields(item) |
231 | | - # Upload any local file paths |
232 | | - item = upload_if_necessary(item) |
233 | | - |
234 | | - if isinstance(data_row_item, DataRowUpsertItem): |
235 | | - return {'id': data_row_item.id, 'payload': item} |
236 | | - else: |
237 | | - return item |
238 | | - |
239 | | - if not isinstance(items, Iterable): |
240 | | - raise ValueError( |
241 | | - f"Must pass an iterable to create_data_rows. Found {type(items)}" |
242 | | - ) |
243 | | - |
244 | | - if len(items) > MAX_DATAROW_PER_API_OPERATION: |
245 | | - raise MalformedQueryException( |
246 | | - f"Cannot create more than {MAX_DATAROW_PER_API_OPERATION} DataRows per function call." |
247 | | - ) |
248 | | - |
249 | | - with ThreadPoolExecutor(file_upload_thread_count) as executor: |
250 | | - futures = [executor.submit(convert_item, item) for item in items] |
251 | | - items = [future.result() for future in as_completed(futures)] |
252 | | - # Prepare and upload the desciptor file |
253 | | - data = json.dumps(items) |
254 | | - return client.upload_data(data, |
255 | | - content_type="application/json", |
256 | | - filename="json_import.json") |
257 | | - |
258 | | - @staticmethod |
259 | | - def upload_in_chunks(client, specs: List[DataRowUpsertItem], |
260 | | - file_upload_thread_count: int, |
261 | | - upsert_chunk_size: int) -> UploadManifest: |
262 | | - empty_specs = list(filter(lambda spec: spec.is_empty(), specs)) |
263 | | - |
264 | | - if empty_specs: |
265 | | - ids = list(map(lambda spec: spec.id.get("value"), empty_specs)) |
266 | | - raise ValueError( |
267 | | - f"The following items have an empty payload: {ids}") |
268 | | - |
269 | | - chunks = [ |
270 | | - specs[i:i + upsert_chunk_size] |
271 | | - for i in range(0, len(specs), upsert_chunk_size) |
272 | | - ] |
273 | | - |
274 | | - def _upload_chunk(chunk): |
275 | | - return DataRowUploader.create_descriptor_file(client, |
276 | | - chunk, |
277 | | - is_upsert=True) |
278 | | - |
279 | | - with ThreadPoolExecutor(file_upload_thread_count) as executor: |
280 | | - futures = [ |
281 | | - executor.submit(_upload_chunk, chunk) for chunk in chunks |
282 | | - ] |
283 | | - chunk_uris = [future.result() for future in as_completed(futures)] |
284 | | - |
285 | | - return UploadManifest(source="SDK", |
286 | | - item_count=len(specs), |
287 | | - chunk_uris=chunk_uris) |
| 31 | + return UploadManifest(source=SOURCE_SDK, |
| 32 | + item_count=len(specs), |
| 33 | + chunk_uris=chunk_uris) |
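
For reviewers, a minimal usage sketch of the new module-level `upload_in_chunks` helper. The import path of this module, the `DataRowUpsertItem` constructor shape (the keys inside the `id` dict), and the parameter values below are illustrative assumptions, not something this diff establishes:

```python
# Illustrative sketch only, not part of the change. Assumes this module is
# importable as labelbox.schema.internal.data_row_uploader and that
# DataRowUpsertItem accepts `id` and `payload` dicts (key names assumed).
from typing import List

from labelbox import Client
from labelbox.schema.internal.data_row_upsert_item import DataRowUpsertItem
from labelbox.schema.internal.data_row_uploader import upload_in_chunks

client = Client()  # assumes LABELBOX_API_KEY is set in the environment

specs: List[DataRowUpsertItem] = [
    DataRowUpsertItem(
        id={"id_type": "GlobalKey", "value": "my-global-key"},   # assumed id shape
        payload={"row_data": "https://example.com/image_01.jpg"},
    ),
]

manifest = upload_in_chunks(
    client,
    specs,
    file_upload_thread_count=20,      # number of parallel chunk uploads
    max_chunk_size_bytes=10_000_000,  # example per-chunk size budget
)
print(manifest.source, manifest.item_count, len(manifest.chunk_uris))
```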