Skip to content

Commit eb946b2

Browse files
authored
Vb/test annotation upload by global keys (#1239)
2 parents 48ae01f + 9ec8d62 commit eb946b2

File tree

5 files changed

+113
-7
lines changed

5 files changed

+113
-7
lines changed

tests/data/serialization/ndjson/test_video.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from labelbox import parser
1313

1414
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
15-
from labelbox.schema.annotation_import import MALPredictionImport
1615

1716

1817
def test_video():

tests/integration/annotation_import/conftest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,12 @@ def configured_project(client, initial_dataset, ontology, rand_gen, image_url):
508508

509509
data_row_ids = []
510510

511-
for _ in range(len(ontology['tools']) + len(ontology['classifications'])):
512-
data_row_ids.append(dataset.create_data_row(row_data=image_url).uid)
511+
ontologies = ontology['tools'] + ontology['classifications']
512+
for ind in range(len(ontologies)):
513+
data_row_ids.append(
514+
dataset.create_data_row(
515+
row_data=image_url,
516+
global_key=f"gk_{ontologies[ind]['name']}_{rand_gen(str)}").uid)
513517
project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids,
514518
sleep_interval=3)
515519

tests/integration/annotation_import/test_data_types.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,51 @@ def test_import_data_types(
180180
data_row.delete()
181181

182182

183+
def test_import_data_types_by_global_key(
184+
client,
185+
configured_project,
186+
initial_dataset,
187+
rand_gen,
188+
data_row_json_by_data_type,
189+
annotations_by_data_type,
190+
):
191+
192+
project = configured_project
193+
project_id = project.uid
194+
dataset = initial_dataset
195+
data_type_class = ImageData
196+
set_project_media_type_from_data_type(project, data_type_class)
197+
198+
data_row_ndjson = data_row_json_by_data_type['image']
199+
data_row_ndjson['global_key'] = str(uuid.uuid4())
200+
data_row = create_data_row_for_project(project, dataset, data_row_ndjson,
201+
rand_gen(str))
202+
203+
annotations_ndjson = annotations_by_data_type['image']
204+
annotations_list = [
205+
label.annotations
206+
for label in NDJsonConverter.deserialize(annotations_ndjson)
207+
]
208+
labels = [
209+
lb_types.Label(data=data_type_class(global_key=data_row.global_key),
210+
annotations=annotations)
211+
for annotations in annotations_list
212+
]
213+
214+
label_import = lb.LabelImport.create_from_objects(client, project_id,
215+
f'test-import-image',
216+
labels)
217+
label_import.wait_until_done()
218+
219+
assert label_import.state == AnnotationImportState.FINISHED
220+
assert len(label_import.errors) == 0
221+
exported_labels = project.export_labels(download=True)
222+
objects = exported_labels[0]['Label']['objects']
223+
classifications = exported_labels[0]['Label']['classifications']
224+
assert len(objects) + len(classifications) == len(labels)
225+
data_row.delete()
226+
227+
183228
def validate_iso_format(date_string: str):
184229
parsed_t = datetime.datetime.fromisoformat(
185230
date_string) #this will blow up if the string is not in iso format
@@ -321,6 +366,17 @@ def one_datarow(client, rand_gen, data_row_json_by_data_type, data_type):
321366
dataset.delete()
322367

323368

369+
@pytest.fixture
370+
def one_datarow_global_key(client, rand_gen, data_row_json_by_data_type):
371+
dataset = client.create_dataset(name=rand_gen(str))
372+
data_row_json = data_row_json_by_data_type['video']
373+
data_row = dataset.create_data_row(data_row_json)
374+
375+
yield data_row
376+
377+
dataset.delete()
378+
379+
324380
@pytest.mark.parametrize('data_type, data_class, annotations', test_params)
325381
def test_import_mal_annotations(client, configured_project_with_one_data_row,
326382
data_type, data_class, annotations, rand_gen,
@@ -348,3 +404,33 @@ def test_import_mal_annotations(client, configured_project_with_one_data_row,
348404

349405
assert import_annotations.errors == []
350406
# MAL Labels cannot be exported and compared to input labels
407+
408+
409+
def test_import_mal_annotations_global_key(client,
410+
configured_project_with_one_data_row,
411+
rand_gen, one_datarow_global_key):
412+
data_class = lb_types.VideoData
413+
data_row = one_datarow_global_key
414+
annotations = [video_mask_annotation]
415+
set_project_media_type_from_data_type(configured_project_with_one_data_row,
416+
data_class)
417+
418+
configured_project_with_one_data_row.create_batch(
419+
rand_gen(str),
420+
[data_row.uid],
421+
)
422+
423+
labels = [
424+
lb_types.Label(data=data_class(global_key=data_row.global_key),
425+
annotations=annotations)
426+
]
427+
428+
import_annotations = lb.MALPredictionImport.create_from_objects(
429+
client=client,
430+
project_id=configured_project_with_one_data_row.uid,
431+
name=f"import {str(uuid.uuid4())}",
432+
predictions=labels)
433+
import_annotations.wait_until_done()
434+
435+
assert import_annotations.errors == []
436+
# MAL Labels cannot be exported and compared to input labels

tests/integration/annotation_import/test_mea_prediction_import.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,25 @@ def test_create_from_objects(model_run_with_data_rows, object_predictions,
3737
annotation_import.wait_until_done()
3838

3939

40+
def test_create_from_objects_global_key(client, model_run_with_data_rows,
41+
entity_inference,
42+
annotation_import_test_helpers):
43+
name = str(uuid.uuid4())
44+
dr = client.get_data_row(entity_inference['dataRow']['id'])
45+
del entity_inference['dataRow']['id']
46+
entity_inference['dataRow']['globalKey'] = dr.global_key
47+
object_predictions = [entity_inference]
48+
49+
annotation_import = model_run_with_data_rows.add_predictions(
50+
name=name, predictions=object_predictions)
51+
52+
assert annotation_import.model_run_id == model_run_with_data_rows.uid
53+
annotation_import_test_helpers.check_running_state(annotation_import, name)
54+
annotation_import_test_helpers.assert_file_content(
55+
annotation_import.input_file_url, object_predictions)
56+
annotation_import.wait_until_done()
57+
58+
4059
def test_create_from_objects_with_confidence(predictions_with_confidence,
4160
model_run_with_data_rows,
4261
annotation_import_test_helpers):

tests/integration/annotation_import/test_upsert_prediction_import.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import uuid
22
from labelbox import parser
33
import pytest
4-
5-
from labelbox.schema.annotation_import import AnnotationImportState, MEAPredictionImport
64
"""
75
- Here we only want to check that the uploads are calling the validation
86
- Then with unit tests we can check the types of errors raised
@@ -28,7 +26,7 @@ def test_create_from_url(client, tmp_path, object_predictions,
2826
if p['dataRow']['id'] in model_run_data_rows
2927
]
3028
with file_path.open("w") as f:
31-
ndjson.dump(predictions, f)
29+
parser.dump(predictions, f)
3230

3331
# Needs to have data row ids
3432

@@ -114,7 +112,7 @@ def test_create_from_local_file(tmp_path, model_run_with_data_rows,
114112
]
115113

116114
with file_path.open("w") as f:
117-
ndjson.dump(predictions, f)
115+
parser.dump(predictions, f)
118116

119117
annotation_import, batch, mal_prediction_import = model_run_with_data_rows.upsert_predictions_and_send_to_project(
120118
name=name,

0 commit comments

Comments
 (0)