
Commit b188201

Merge pull request #571 from Labelbox/kkim/AL-2219
[AL-2219] Add custom_metadata to input file for create_data_rows()
2 parents 0f7c336 + 34054d8
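
This change lets create_data_rows() and create_data_row() accept custom metadata either as DataRowMetadataField objects or as plain dicts with "schema_id" and "value" keys. A minimal usage sketch of the new dict form; the API key, dataset ID, schema IDs, and image URL below are placeholders, not values from this commit:

from labelbox import Client
from labelbox.schema.data_row_metadata import DataRowMetadataField

client = Client(api_key="<API_KEY>")          # placeholder credentials
dataset = client.get_dataset("<DATASET_ID>")  # placeholder dataset

# Dict-style metadata is converted internally by parse_upsert_metadata();
# DataRowMetadataField objects and dicts can be mixed in one call.
task = dataset.create_data_rows([{
    "row_data": "https://example.com/image.jpg",  # placeholder URL
    "external_id": "row1",
    "custom_metadata": [
        DataRowMetadataField(schema_id="<TEXT_SCHEMA_ID>", value="A message"),
        {"schema_id": "<SPLIT_SCHEMA_ID>", "value": "<SPLIT_OPTION_ID>"},
    ],
}])
task.wait_till_done()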

File tree

5 files changed (+190, -39 lines)

labelbox/schema/data_row_metadata.py
labelbox/schema/dataset.py
tests/integration/test_batch.py
tests/integration/test_data_row_metadata.py
tests/integration/test_data_rows.py

labelbox/schema/data_row_metadata.py

Lines changed: 29 additions & 2 deletions
@@ -307,7 +307,7 @@ def _batch_upsert(
                     data_row_id=m.data_row_id,
                     fields=list(
                         chain.from_iterable(
-                            self.parse_upsert(m) for m in m.fields))).dict(
+                            self._parse_upsert(m) for m in m.fields))).dict(
                                 by_alias=True))
         res = _batch_operations(_batch_upsert, items, self._batch_size)
         return res
@@ -404,7 +404,7 @@ def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
             data_row_ids,
             batch_size=self._batch_size)
 
-    def parse_upsert(
+    def _parse_upsert(
             self, metadatum: DataRowMetadataField
     ) -> List[_UpsertDataRowMetadataInput]:
         """Format for metadata upserts to GQL"""
@@ -435,6 +435,33 @@ def parse_upsert(
 
         return [_UpsertDataRowMetadataInput(**p) for p in parsed]
 
+    # Convert metadata to DataRowMetadataField objects, parse all fields
+    # and return a dictionary of metadata fields for upsert
+    def parse_upsert_metadata(self, metadata_fields):
+
+        def _convert_metadata_field(metadata_field):
+            if isinstance(metadata_field, DataRowMetadataField):
+                return metadata_field
+            elif isinstance(metadata_field, dict):
+                if not all(key in metadata_field
+                           for key in ("schema_id", "value")):
+                    raise ValueError(
+                        f"Custom metadata field '{metadata_field}' must have 'schema_id' and 'value' keys"
+                    )
+                return DataRowMetadataField(
+                    schema_id=metadata_field["schema_id"],
+                    value=metadata_field["value"])
+            else:
+                raise ValueError(
+                    f"Metadata field '{metadata_field}' is neither 'DataRowMetadataField' type or a dictionary"
+                )
+
+        # Convert all metadata fields to DataRowMetadataField type
+        metadata_fields = [_convert_metadata_field(m) for m in metadata_fields]
+        parsed_metadata = list(
+            chain.from_iterable(self._parse_upsert(m) for m in metadata_fields))
+        return [m.dict(by_alias=True) for m in parsed_metadata]
+
     def _validate_delete(self, delete: DeleteDataRowMetadata):
         if not len(delete.fields):
             raise ValueError(f"No fields specified for {delete.data_row_id}")
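
The renamed _parse_upsert() stays private; the new public parse_upsert_metadata() is the single entry point that normalizes mixed metadata input and serializes it for GQL. A rough sketch of the round trip, assuming a client as in the sketch above and valid schema IDs:

mdo = client.get_data_row_metadata_ontology()

fields = [
    DataRowMetadataField(schema_id="<TEXT_SCHEMA_ID>", value="A message"),
    {"schema_id": "<SPLIT_SCHEMA_ID>", "value": "<SPLIT_OPTION_ID>"},  # dict form
]

# Each entry is coerced to DataRowMetadataField, parsed by _parse_upsert(),
# and serialized with camelCase aliases (e.g. schemaId) for the GQL payload.
parsed = mdo.parse_upsert_metadata(fields)

# Malformed entries fail fast with ValueError, before any network call:
# mdo.parse_upsert_metadata([{"value": "missing schema_id"}])
# mdo.parse_upsert_metadata(["neither a field nor a dict"])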

labelbox/schema/dataset.py

Lines changed: 12 additions & 8 deletions
@@ -5,7 +5,7 @@
 from collections.abc import Iterable
 import time
 import ndjson
-from itertools import islice, chain
+from itertools import islice
 
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from io import StringIO
@@ -82,13 +82,8 @@ def create_data_row(self, **kwargs) -> "DataRow":
         # Parse metadata fields, if they are provided
         if DataRow.custom_metadata.name in kwargs:
             mdo = self.client.get_data_row_metadata_ontology()
-            metadata_fields = kwargs[DataRow.custom_metadata.name]
-            metadata = list(
-                chain.from_iterable(
-                    mdo.parse_upsert(m) for m in metadata_fields))
-            kwargs[DataRow.custom_metadata.name] = [
-                md.dict(by_alias=True) for md in metadata
-            ]
+            kwargs[DataRow.custom_metadata.name] = mdo.parse_upsert_metadata(
+                kwargs[DataRow.custom_metadata.name])
 
         return self.client._create(DataRow, kwargs)
 
@@ -268,6 +263,13 @@ def validate_attachments(item):
                 )
             return attachments
 
+        def parse_metadata_fields(item):
+            metadata_fields = item.get('custom_metadata')
+            if metadata_fields:
+                mdo = self.client.get_data_row_metadata_ontology()
+                item['custom_metadata'] = mdo.parse_upsert_metadata(
+                    metadata_fields)
+
         def format_row(item):
             # Formats user input into a consistent dict structure
             if isinstance(item, dict):
@@ -308,6 +310,8 @@ def convert_item(item):
             validate_keys(item)
             # Make sure attachments are valid
             validate_attachments(item)
+            # Parse metadata fields if they exist
+            parse_metadata_fields(item)
             # Upload any local file paths
             item = upload_if_necessary(item)
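After this refactor, the single-row create_data_row() path and the bulk create_data_rows() path both funnel metadata through the same parse_upsert_metadata() call, so the dict form works in either. A short sketch with a placeholder schema ID:

from datetime import datetime

data_row = dataset.create_data_row(
    row_data="https://example.com/image.jpg",  # placeholder URL
    custom_metadata=[{
        "schema_id": "<CAPTURE_DT_SCHEMA_ID>",  # placeholder schema ID
        "value": datetime.utcnow()              # datetime values are supported
    }])
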
tests/integration/test_batch.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ def test_create_batch(configured_project: Project, big_dataset: Dataset):
 
     data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())]
     batch = configured_project.create_batch("test-batch", data_rows, 3)
-    assert batch.name == 'test-batch'
+    assert batch.name == "test-batch"
     assert batch.size == len(data_rows)
 
 
@@ -79,4 +79,4 @@ def test_export_data_rows(configured_project: Project, dataset: Dataset):
     exported_data_rows = [dr.uid for dr in result]
 
     assert len(result) == n_data_rows
-    assert set(data_rows) == set(exported_data_rows)
+    assert set(data_rows) == set(exported_data_rows)

tests/integration/test_data_row_metadata.py

Lines changed: 1 addition & 1 deletion
@@ -281,4 +281,4 @@ def test_parse_raw_metadata(mdo):
 
     for row in parsed:
         for field in row.fields:
-            assert mdo.parse_upsert(field)
+            assert mdo._parse_upsert(field)

tests/integration/test_data_rows.py

Lines changed: 146 additions & 26 deletions
@@ -15,6 +15,10 @@
 EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
 TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
 CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
+EXPECTED_METADATA_SCHEMA_IDS = [
+    SPLIT_SCHEMA_ID, TEST_SPLIT_ID, EMBEDDING_SCHEMA_ID, TEXT_SCHEMA_ID,
+    CAPTURE_DT_SCHEMA_ID
+].sort()
 
 
 def make_metadata_fields():
@@ -31,6 +35,27 @@ def make_metadata_fields():
     return fields
 
 
+def make_metadata_fields_dict():
+    embeddings = [0.0] * 128
+    msg = "A message"
+    time = datetime.utcnow()
+
+    fields = [{
+        "schema_id": SPLIT_SCHEMA_ID,
+        "value": TEST_SPLIT_ID
+    }, {
+        "schema_id": CAPTURE_DT_SCHEMA_ID,
+        "value": time
+    }, {
+        "schema_id": TEXT_SCHEMA_ID,
+        "value": msg
+    }, {
+        "schema_id": EMBEDDING_SCHEMA_ID,
+        "value": embeddings
+    }]
+    return fields
+
+
 def test_get_data_row(datarow, client):
     assert client.get_data_row(datarow.uid)
 
@@ -152,7 +177,7 @@ def test_data_row_single_creation(dataset, rand_gen, image_url):
     assert requests.get(data_row_2.row_data).content == data
 
 
-def test_data_row_single_creation_with_metadata(dataset, rand_gen, image_url):
+def test_create_data_row_with_metadata(dataset, image_url):
     client = dataset.client
     assert len(list(dataset.data_rows())) == 0
 
@@ -167,38 +192,133 @@ def test_data_row_single_creation_with_metadata(dataset, rand_gen, image_url):
         requests.get(data_row.row_data).content
     assert data_row.media_attributes is not None
     assert len(data_row.custom_metadata) == 5
+    assert [m["schemaId"] for m in data_row.custom_metadata
+           ].sort() == EXPECTED_METADATA_SCHEMA_IDS
 
-    with NamedTemporaryFile() as fp:
-        data = rand_gen(str).encode()
-        fp.write(data)
-        fp.flush()
-        data_row_2 = dataset.create_data_row(row_data=fp.name)
-        assert len(list(dataset.data_rows())) == 2
-        assert requests.get(data_row_2.row_data).content == data
 
+def test_create_data_row_with_metadata_dict(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
 
-def test_data_row_single_creation_with_invalid_metadata(dataset, image_url):
+    data_row = dataset.create_data_row(
+        row_data=image_url, custom_metadata=make_metadata_fields_dict())
 
-    def make_invalid_metadata_fields():
-        embeddings = [0.0] * 128
-        msg = "A message"
-        time = datetime.utcnow()
+    assert len(list(dataset.data_rows())) == 1
+    assert data_row.dataset() == dataset
+    assert data_row.created_by() == client.get_user()
+    assert data_row.organization() == client.get_organization()
+    assert requests.get(image_url).content == \
+        requests.get(data_row.row_data).content
+    assert data_row.media_attributes is not None
+    assert len(data_row.custom_metadata) == 5
+    assert [m["schemaId"] for m in data_row.custom_metadata
+           ].sort() == EXPECTED_METADATA_SCHEMA_IDS
 
-        fields = [
-            DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID,
-                                 value=TEST_SPLIT_ID),
-            DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time),
-            DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg),
-            DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID,
-                                 value=embeddings),
-            DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID,
-                                 value=embeddings),
-        ]
-        return fields
+
+def test_create_data_row_with_invalid_metadata(dataset, image_url):
+    fields = make_metadata_fields()
+    fields.append(
+        DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID, value=[0.0] * 128))
 
     with pytest.raises(labelbox.exceptions.MalformedQueryException) as excinfo:
-        dataset.create_data_row(row_data=image_url,
-                                custom_metadata=make_invalid_metadata_fields())
+        dataset.create_data_row(row_data=image_url, custom_metadata=fields)
+
+
+def test_create_data_rows_with_metadata(dataset, image_url):
+    client = dataset.client
+    assert len(list(dataset.data_rows())) == 0
+
+    task = dataset.create_data_rows([
+        {
+            DataRow.row_data: image_url,
+            DataRow.external_id: "row1",
+            DataRow.custom_metadata: make_metadata_fields()
+        },
+        {
+            DataRow.row_data: image_url,
+            DataRow.external_id: "row2",
+            "custom_metadata": make_metadata_fields()
+        },
+        {
+            DataRow.row_data: image_url,
+            DataRow.external_id: "row3",
+            DataRow.custom_metadata: make_metadata_fields_dict()
+        },
+        {
+            DataRow.row_data: image_url,
+            DataRow.external_id: "row4",
+            "custom_metadata": make_metadata_fields_dict()
+        },
+    ])
+    task.wait_till_done()
+
+    assert len(list(dataset.data_rows())) == 4
+    for r in ["row1", "row2", "row3", "row4"]:
+        row = list(dataset.data_rows(where=DataRow.external_id == r))[0]
+        assert row.dataset() == dataset
+        assert row.created_by() == client.get_user()
+        assert row.organization() == client.get_organization()
+        assert requests.get(image_url).content == \
+            requests.get(row.row_data).content
+        assert row.media_attributes is not None
+        assert len(row.custom_metadata) == 5
+        assert [m["schemaId"] for m in row.custom_metadata
+               ].sort() == EXPECTED_METADATA_SCHEMA_IDS
+
+
+def test_create_data_rows_with_invalid_metadata(dataset, image_url):
+    fields = make_metadata_fields()
+    fields.append(
+        DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID, value=[0.0] * 128))
+
+    task = dataset.create_data_rows([{
+        DataRow.row_data: image_url,
+        DataRow.custom_metadata: fields
+    }])
+    task.wait_till_done()
+    assert task.status == "FAILED"
+
+
+def test_create_data_rows_with_metadata_missing_value(dataset, image_url):
+    fields = make_metadata_fields()
+    fields.append({"schemaId": "some schema id"})
+
+    with pytest.raises(ValueError) as exc:
+        dataset.create_data_rows([
+            {
+                DataRow.row_data: image_url,
+                DataRow.external_id: "row1",
+                DataRow.custom_metadata: fields
+            },
+        ])
+
+
+def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url):
+    fields = make_metadata_fields()
+    fields.append({"value": "some value"})
+
+    with pytest.raises(ValueError) as exc:
+        dataset.create_data_rows([
+            {
+                DataRow.row_data: image_url,
+                DataRow.external_id: "row1",
+                DataRow.custom_metadata: fields
+            },
+        ])
+
+
+def test_create_data_rows_with_metadata_wrong_type(dataset, image_url):
+    fields = make_metadata_fields()
+    fields.append("Neither DataRowMetadataField or dict")
+
+    with pytest.raises(ValueError) as exc:
+        task = dataset.create_data_rows([
+            {
+                DataRow.row_data: image_url,
+                DataRow.external_id: "row1",
+                DataRow.custom_metadata: fields
+            },
+        ])
 
 
 def test_data_row_update(dataset, rand_gen, image_url):
