Skip to content

Commit 632432d

Browse files
author
Kevin Kim
committed
Adding tests
1 parent f6ce4b6 commit 632432d

File tree

2 files changed

+175
-58
lines changed

2 files changed

+175
-58
lines changed

labelbox/schema/dataset.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,8 @@ def create_data_row(self, **kwargs) -> "DataRow":
8282

8383
# Parse metadata fields, if they are provided
8484
if DataRow.custom_metadata.name in kwargs:
85-
mdo = self.client.get_data_row_metadata_ontology()
86-
metadata_fields = kwargs[DataRow.custom_metadata.name]
87-
metadata = list(
88-
chain.from_iterable(
89-
mdo.parse_upsert(m) for m in metadata_fields))
90-
kwargs[DataRow.custom_metadata.name] = [
91-
md.dict(by_alias=True) for md in metadata
92-
]
85+
kwargs[DataRow.custom_metadata.name] = self._parse_metadata(
86+
kwargs[DataRow.custom_metadata.name])
9387

9488
return self.client._create(DataRow, kwargs)
9589

@@ -269,32 +263,10 @@ def validate_attachments(item):
269263
)
270264
return attachments
271265

272-
def convert_metadata_field(metadata_field):
273-
if isinstance(metadata_field, DataRowMetadataField):
274-
return metadata_field
275-
elif isinstance(metadata_field, dict):
276-
if not all (key in metadata_field for key in ("schema_id", "value")):
277-
raise ValueError(f"Custom metadata fields must have 'schema_id' and 'value' keys")
278-
return DataRowMetadataField(schema_id=metadata_field["schema_id"], value=metadata_field["value"])
279-
else:
280-
raise ValueError(f"Metadata field is neither 'DataRowMetadataField' type or a dictionary!")
281-
282266
def parse_metadata_fields(item):
283267
metadata_fields = item.get('custom_metadata')
284268
if metadata_fields:
285-
# Convert all metadata fields to DataRowMetadataField type
286-
metadata_fields = [
287-
convert_metadata_field(m)
288-
for m in metadata_fields
289-
]
290-
mdo = self.client.get_data_row_metadata_ontology()
291-
metadata = list(
292-
chain.from_iterable(
293-
mdo.parse_upsert(m) for m in metadata_fields))
294-
metadata_fields = [
295-
md.dict(by_alias=True) for md in metadata
296-
]
297-
item['custom_metadata'] = metadata_fields
269+
item['custom_metadata'] = self._parse_metadata(metadata_fields)
298270

299271
def format_row(item):
300272
# Formats user input into a consistent dict structure
@@ -329,7 +301,7 @@ def convert_item(item):
329301
# Don't make any changes to tms data
330302
if "tileLayerUrl" in item:
331303
validate_attachments(item)
332-
return item
304+
return item
333305
# Convert all payload variations into the same dict format
334306
item = format_row(item)
335307
# Make sure required keys exist (and there are no extra keys)
@@ -455,3 +427,28 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
455427
logger.debug("Dataset '%s' data row export, waiting for server...",
456428
self.uid)
457429
time.sleep(sleep_time)
430+
431+
def _convert_metadata_field(self, metadata_field):
    """Normalize one custom-metadata entry to a ``DataRowMetadataField``.

    Args:
        metadata_field: Either a ``DataRowMetadataField`` (returned as-is)
            or a dict carrying both a ``schema_id`` and a ``value`` key.

    Returns:
        The given ``DataRowMetadataField``, or a new one built from the
        dict's ``schema_id`` and ``value`` entries.

    Raises:
        ValueError: If the entry is neither a ``DataRowMetadataField`` nor a
            dict, or if it is a dict lacking one of the required keys.
    """
    # Already in the target type: nothing to convert.
    if isinstance(metadata_field, DataRowMetadataField):
        return metadata_field

    # Anything other than a dict cannot be converted.
    if not isinstance(metadata_field, dict):
        raise ValueError(
            f"Metadata field '{metadata_field}' is neither 'DataRowMetadataField' type or a dictionary!"
        )

    # Both keys are mandatory; the first absent one rejects the entry.
    if any(key not in metadata_field for key in ("schema_id", "value")):
        raise ValueError(
            f"Custom metadata field '{metadata_field}' must have 'schema_id' and 'value' keys"
        )

    return DataRowMetadataField(schema_id=metadata_field["schema_id"],
                                value=metadata_field["value"])
445+
446+
def _parse_metadata(self, metadata_fields):
    """Turn user-supplied metadata entries into upsert-ready dicts.

    Args:
        metadata_fields: Iterable of ``DataRowMetadataField`` objects and/or
            dicts with ``schema_id`` and ``value`` keys.

    Returns:
        A list with one alias-keyed dict per object produced by the metadata
        ontology's ``parse_upsert`` (a single field may expand to several).

    Raises:
        ValueError: Propagated from ``_convert_metadata_field`` when an entry
            cannot be converted.
    """
    # Validate and convert every entry first, so malformed input fails
    # before the metadata ontology is requested from the client.
    normalized = [
        self._convert_metadata_field(entry) for entry in metadata_fields
    ]

    ontology = self.client.get_data_row_metadata_ontology()

    # parse_upsert yields an iterable per field; flatten them in order.
    upserts = []
    for field in normalized:
        upserts.extend(ontology.parse_upsert(field))

    return [upsert.dict(by_alias=True) for upsert in upserts]

tests/integration/test_data_rows.py

Lines changed: 146 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
1616
TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
1717
CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
18+
# Sorted schema ids expected on a data row created with this module's
# metadata fixtures. Built with sorted() rather than list.sort(): sort()
# works in place and returns None, which would bind this name to None and
# make every equality check against another `.sort()` result pass vacuously.
# Compare against it with `sorted(ids)`, never `ids.sort()`.
EXPECTED_METADATA_SCHEMA_IDS = sorted([
    SPLIT_SCHEMA_ID, TEST_SPLIT_ID, EMBEDDING_SCHEMA_ID, TEXT_SCHEMA_ID,
    CAPTURE_DT_SCHEMA_ID
])
1822

1923

2024
def make_metadata_fields():
@@ -31,6 +35,27 @@ def make_metadata_fields():
3135
return fields
3236

3337

38+
def make_metadata_fields_dict():
    """Build the metadata fixture in plain-dict form.

    Returns:
        A list of four ``{"schema_id": ..., "value": ...}`` dicts covering
        the split, capture-datetime, text and embedding schemas, in that
        order.
    """
    captured_at = datetime.utcnow()
    # (schema id, value) pairs, kept in the order they are emitted.
    schema_values = (
        (SPLIT_SCHEMA_ID, TEST_SPLIT_ID),
        (CAPTURE_DT_SCHEMA_ID, captured_at),
        (TEXT_SCHEMA_ID, "A message"),
        (EMBEDDING_SCHEMA_ID, [0.0] * 128),
    )
    return [{
        "schema_id": schema_id,
        "value": value
    } for schema_id, value in schema_values]
57+
58+
3459
def test_get_data_row(datarow, client):
3560
assert client.get_data_row(datarow.uid)
3661

@@ -152,7 +177,7 @@ def test_data_row_single_creation(dataset, rand_gen, image_url):
152177
assert requests.get(data_row_2.row_data).content == data
153178

154179

155-
def test_data_row_single_creation_with_metadata(dataset, rand_gen, image_url):
180+
def test_create_data_row_with_metadata(dataset, image_url):
156181
client = dataset.client
157182
assert len(list(dataset.data_rows())) == 0
158183

@@ -167,38 +192,133 @@ def test_data_row_single_creation_with_metadata(dataset, rand_gen, image_url):
167192
requests.get(data_row.row_data).content
168193
assert data_row.media_attributes is not None
169194
assert len(data_row.custom_metadata) == 5
195+
assert [m["schemaId"] for m in data_row.custom_metadata
196+
].sort() == EXPECTED_METADATA_SCHEMA_IDS
170197

171-
with NamedTemporaryFile() as fp:
172-
data = rand_gen(str).encode()
173-
fp.write(data)
174-
fp.flush()
175-
data_row_2 = dataset.create_data_row(row_data=fp.name)
176-
assert len(list(dataset.data_rows())) == 2
177-
assert requests.get(data_row_2.row_data).content == data
178198

199+
def test_create_data_row_with_metadata_dict(dataset, image_url):
200+
client = dataset.client
201+
assert len(list(dataset.data_rows())) == 0
179202

180-
def test_data_row_single_creation_with_invalid_metadata(dataset, image_url):
203+
data_row = dataset.create_data_row(
204+
row_data=image_url, custom_metadata=make_metadata_fields_dict())
181205

182-
def make_invalid_metadata_fields():
183-
embeddings = [0.0] * 128
184-
msg = "A message"
185-
time = datetime.utcnow()
206+
assert len(list(dataset.data_rows())) == 1
207+
assert data_row.dataset() == dataset
208+
assert data_row.created_by() == client.get_user()
209+
assert data_row.organization() == client.get_organization()
210+
assert requests.get(image_url).content == \
211+
requests.get(data_row.row_data).content
212+
assert data_row.media_attributes is not None
213+
assert len(data_row.custom_metadata) == 5
214+
assert [m["schemaId"] for m in data_row.custom_metadata
215+
].sort() == EXPECTED_METADATA_SCHEMA_IDS
186216

187-
fields = [
188-
DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID,
189-
value=TEST_SPLIT_ID),
190-
DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time),
191-
DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg),
192-
DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID,
193-
value=embeddings),
194-
DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID,
195-
value=embeddings),
196-
]
197-
return fields
217+
218+
def test_create_data_row_with_invalid_metadata(dataset, image_url):
219+
fields = make_metadata_fields()
220+
fields.append(
221+
DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID, value=[0.0] * 128))
198222

199223
with pytest.raises(labelbox.exceptions.MalformedQueryException) as excinfo:
200-
dataset.create_data_row(row_data=image_url,
201-
custom_metadata=make_invalid_metadata_fields())
224+
dataset.create_data_row(row_data=image_url, custom_metadata=fields)
225+
226+
227+
def test_create_data_rows_with_metadata(dataset, image_url):
    """Bulk-create four rows with metadata and verify each created row.

    The rows cover both ways of keying the metadata (``DataRow.custom_metadata``
    and the literal ``"custom_metadata"``) and both fixture forms
    (``make_metadata_fields`` and ``make_metadata_fields_dict``).
    """
    client = dataset.client
    assert len(list(dataset.data_rows())) == 0

    task = dataset.create_data_rows([
        {
            DataRow.row_data: image_url,
            DataRow.external_id: "row1",
            DataRow.custom_metadata: make_metadata_fields()
        },
        {
            DataRow.row_data: image_url,
            DataRow.external_id: "row2",
            "custom_metadata": make_metadata_fields()
        },
        {
            DataRow.row_data: image_url,
            DataRow.external_id: "row3",
            DataRow.custom_metadata: make_metadata_fields_dict()
        },
        {
            DataRow.row_data: image_url,
            DataRow.external_id: "row4",
            "custom_metadata": make_metadata_fields_dict()
        },
    ])
    task.wait_till_done()

    # Both sides of the schema-id comparison are built with sorted():
    # list.sort() returns None, so comparing `.sort()` results would be
    # `None == None` and could never fail.
    expected_schema_ids = sorted([
        SPLIT_SCHEMA_ID, TEST_SPLIT_ID, EMBEDDING_SCHEMA_ID, TEXT_SCHEMA_ID,
        CAPTURE_DT_SCHEMA_ID
    ])

    assert len(list(dataset.data_rows())) == 4
    for r in ["row1", "row2", "row3", "row4"]:
        row = list(dataset.data_rows(where=DataRow.external_id == r))[0]
        assert row.dataset() == dataset
        assert row.created_by() == client.get_user()
        assert row.organization() == client.get_organization()
        assert requests.get(image_url).content == \
            requests.get(row.row_data).content
        assert row.media_attributes is not None
        assert len(row.custom_metadata) == 5
        assert sorted(
            m["schemaId"] for m in row.custom_metadata) == expected_schema_ids
267+
268+
269+
def test_create_data_rows_with_invalid_metadata(dataset, image_url):
    """Bulk creation with an extra embedding field must end as FAILED.

    The extra field presumably duplicates an embedding entry already present
    in ``make_metadata_fields()`` - confirm against that fixture.
    """
    invalid_fields = [
        *make_metadata_fields(),
        DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID, value=[0.0] * 128),
    ]
    payload = [{
        DataRow.row_data: image_url,
        DataRow.custom_metadata: invalid_fields
    }]

    task = dataset.create_data_rows(payload)
    task.wait_till_done()

    # The rejection surfaces through the task status, not an exception.
    assert task.status == "FAILED"
280+
281+
282+
def test_create_data_rows_with_metadata_missing_value(dataset, image_url):
    """A dict metadata field without a 'value' key must raise ValueError."""
    fields = make_metadata_fields()
    # Use the 'schema_id' key the validator looks for, so that 'value' is the
    # only missing key. With 'schemaId' the dict lacked both required keys and
    # the test did not isolate the missing-value case.
    fields.append({"schema_id": "some schema id"})

    with pytest.raises(ValueError):
        dataset.create_data_rows([
            {
                DataRow.row_data: image_url,
                DataRow.external_id: "row1",
                DataRow.custom_metadata: fields
            },
        ])
294+
295+
296+
def test_create_data_rows_with_metadata_missing_schema_id(dataset, image_url):
    """A dict metadata field without a 'schema_id' key must raise ValueError."""
    fields_without_schema_id = [
        *make_metadata_fields(),
        {"value": "some value"},
    ]

    with pytest.raises(ValueError):
        dataset.create_data_rows([
            {
                DataRow.row_data: image_url,
                DataRow.external_id: "row1",
                DataRow.custom_metadata: fields_without_schema_id
            },
        ])
308+
309+
310+
def test_create_data_rows_with_metadata_wrong_type(dataset, image_url):
    """A metadata entry that is neither a field object nor a dict must raise ValueError."""
    fields_with_bad_entry = [
        *make_metadata_fields(),
        "Neither DataRowMetadataField or dict",
    ]

    # The call is expected to raise, so its return value is never bound.
    with pytest.raises(ValueError):
        dataset.create_data_rows([
            {
                DataRow.row_data: image_url,
                DataRow.external_id: "row1",
                DataRow.custom_metadata: fields_with_bad_entry
            },
        ])
202322

203323

204324
def test_data_row_update(dataset, rand_gen, image_url):

0 commit comments

Comments
 (0)