Skip to content

Commit b35518b

Browse files
author
Dmitriy Apollonin
committed
validate labels format prior sending to api
1 parent f8e612d commit b35518b

File tree

4 files changed

+82
-16
lines changed

4 files changed

+82
-16
lines changed

labelbox/schema/annotation_import.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,17 @@ def _create_from_bytes(cls, client, variables, query_str, file_name,
141141
files = {file_name: file_data}
142142
return client.execute(data=data, files=files)
143143

144+
@classmethod
145+
def _get_ndjson_from_objects(cls, data, slug):
146+
if not isinstance(data, list):
147+
raise TypeError(f"{slug} must be in a form of list. Found {type(data)}")
148+
149+
data_str = ndjson.dumps(data)
150+
if not data_str:
151+
raise ValueError(f"{slug} cannot be empty")
152+
153+
return data_str.encode('utf-8')
154+
144155
def refresh(self) -> None:
145156
"""Synchronizes values of all fields with the database.
146157
"""
@@ -198,7 +209,7 @@ def create_from_file(cls, client: "labelbox.Client", model_run_id: str,
198209

199210
@classmethod
200211
def create_from_objects(cls, client: "labelbox.Client", model_run_id: str,
201-
name, predictions) -> "MEAPredictionImport":
212+
name, predictions: List[Dict[str, Any]]) -> "MEAPredictionImport":
202213
"""
203214
Create an MEA prediction import job from an in memory dictionary
204215
@@ -210,10 +221,8 @@ def create_from_objects(cls, client: "labelbox.Client", model_run_id: str,
210221
Returns:
211222
MEAPredictionImport
212223
"""
213-
data_str = ndjson.dumps(predictions)
214-
if not data_str:
215-
raise ValueError('annotations cannot be empty')
216-
data = data_str.encode('utf-8')
224+
data = cls._get_ndjson_from_objects(predictions, 'annotations')
225+
217226
return cls._create_mea_import_from_bytes(client, model_run_id, name,
218227
data, len(data))
219228

@@ -448,16 +457,13 @@ def create_from_objects(
448457
Returns:
449458
MALPredictionImport
450459
"""
451-
data_str = ndjson.dumps(predictions)
452-
if not data_str:
453-
raise ValueError('annotations cannot be empty')
454-
data = data_str.encode('utf-8')
460+
data = cls._get_ndjson_from_objects(predictions, 'annotations')
455461

456462
has_confidence = LabelsConfidencePresenceChecker.check(predictions)
457463
if has_confidence:
458464
logger.warning("""
459-
Confidence scores are not supported in MAL Prediction Import.
460-
Corresponding confidence score values will be ingored.
465+
Confidence scores are not supported in MAL Prediction Import.
466+
Corresponding confidence score values will be ignored.
461467
""")
462468
return cls._create_mal_import_from_bytes(client, project_id, name, data,
463469
len(data))
@@ -607,15 +613,12 @@ def create_from_objects(cls, client: "labelbox.Client", project_id: str,
607613
Returns:
608614
LabelImport
609615
"""
610-
data_str = ndjson.dumps(labels)
611-
if not data_str:
612-
raise ValueError('labels cannot be empty')
613-
data = data_str.encode('utf-8')
616+
data = cls._get_ndjson_from_objects(labels, 'labels')
614617

615618
has_confidence = LabelsConfidencePresenceChecker.check(labels)
616619
if has_confidence:
617620
logger.warning("""
618-
Confidence scores are not supported in Label Import.
621+
Confidence scores are not supported in Label Import.
619622
Corresponding confidence score values will be ignored.
620623
""")
621624
return cls._create_label_import_from_bytes(client, project_id, name,

labelbox/schema/bulk_import_request.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ def create_from_objects(cls,
310310
Returns:
311311
BulkImportRequest object
312312
"""
313+
if not isinstance(predictions, list):
314+
raise TypeError(f"annotations must be in a form of Iterable. Found {type(predictions)}")
315+
313316
if validate:
314317
_validate_ndjson(predictions, client.get_project(project_id))
315318

tests/unit/test_mal_import.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
import pytest
23
from unittest.mock import MagicMock, patch
34

45
from labelbox.schema.annotation_import import MALPredictionImport, logger
@@ -39,3 +40,36 @@ def test_should_warn_user_about_unsupported_confidence():
3940
warning_mock.assert_called_once()
4041
"Confidence scores are not supported in MAL Prediction Import" in warning_mock.call_args_list[
4142
0].args[0]
43+
44+
45+
def test_invalid_labels_format():
46+
"""this test should confirm that annotations are required to be in a form of dict"""
47+
id = str(uuid.uuid4())
48+
49+
label = {
50+
"bbox": {
51+
"height": 428,
52+
"left": 2089,
53+
"top": 1251,
54+
"width": 158
55+
},
56+
"classifications": [{
57+
"answer": [{
58+
"schemaId": "ckrb1sfl8099e0y919v260awv",
59+
"confidence": 0.894
60+
}],
61+
"schemaId": "ckrb1sfkn099c0y910wbo0p1a"
62+
}],
63+
"dataRow": {
64+
"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"
65+
},
66+
"schemaId": "ckrb1sfjx099a0y914hl319ie",
67+
"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"
68+
}
69+
70+
with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'):
71+
with pytest.raises(TypeError):
72+
MALPredictionImport.create_from_objects(client=MagicMock(),
73+
project_id=id,
74+
name=id,
75+
predictions=label)

tests/unit/test_unit_label_import.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import uuid
2+
import pytest
23
from unittest.mock import MagicMock, patch
34

45
from labelbox.schema.annotation_import import LabelImport, logger
@@ -33,3 +34,28 @@ def test_should_warn_user_about_unsupported_confidence():
3334
warning_mock.assert_called_once()
3435
"Confidence scores are not supported in Label Import" in warning_mock.call_args_list[
3536
0].args[0]
37+
38+
39+
def test_invalid_labels_format():
40+
"""this test should confirm that labels are required to be in a form of dict"""
41+
id = str(uuid.uuid4())
42+
43+
label = {
44+
"uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
45+
"schemaId": "ckrazcueb16og0z6609jj7y3y",
46+
"dataRow": {
47+
"id": "ckrazctum0z8a0ybc0b0o0g0v"
48+
},
49+
"bbox": {
50+
"top": 1352,
51+
"left": 2275,
52+
"height": 350,
53+
"width": 139
54+
}
55+
}
56+
with patch.object(LabelImport, '_create_label_import_from_bytes'):
57+
with pytest.raises(TypeError):
58+
LabelImport.create_from_objects(client=MagicMock(),
59+
project_id=id,
60+
name=id,
61+
labels=label)

0 commit comments

Comments
 (0)