Skip to content

Commit 33a6c8f

Browse files
committed
AL-3578: Added MAL and LI warning message
1 parent 86c505e commit 33a6c8f

File tree

6 files changed

+157
-7
lines changed

6 files changed

+157
-7
lines changed

labelbox/data/annotation_types/label.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def frame_annotations(
6262
frame_dict = defaultdict(list)
6363
for annotation in self.annotations:
6464
if isinstance(
65-
annotation,
66-
(VideoObjectAnnotation, VideoClassificationAnnotation)):
65+
annotation,
66+
(VideoObjectAnnotation, VideoClassificationAnnotation)):
6767
frame_dict[annotation.frame].append(annotation)
6868
return frame_dict
6969

@@ -151,7 +151,8 @@ def assign_feature_schema_ids(
151151
elif isinstance(annotation, ObjectAnnotation):
152152
self._assign_or_raise(annotation, tool_lookup)
153153
for classification in annotation.classifications:
154-
self._assign_or_raise(classification, classification_lookup)
154+
self._assign_or_raise(
155+
classification, classification_lookup)
155156
self._assign_option(classification, classification_lookup)
156157
else:
157158
raise TypeError(

labelbox/data/mixins.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ class ConfidenceMixin(BaseModel):
1010

1111
@validator('confidence')
1212
def confidence_valid_float(cls, value):
13+
if value is None:
14+
return value
1315
if not isinstance(value, (int, float)) or not 0 <= value <= 1:
1416
raise ValueError('must be float within [0,1] range')
1517
return value

labelbox/schema/annotation_import.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
import logging
44
import os
55
import time
6-
from typing import Any, Dict, List, BinaryIO
7-
from tqdm import tqdm # type: ignore
6+
from typing import Any, BinaryIO, Dict, List
87

98
import backoff
109
import ndjson
1110
import requests
11+
from tqdm import tqdm # type: ignore
1212

1313
import labelbox
1414
from labelbox.orm import query
1515
from labelbox.orm.db_object import DbObject
1616
from labelbox.orm.model import Field, Relationship
17+
from labelbox.schema.annotation_import_validators import \
18+
LabelsConfidencePresenceChecker
1719
from labelbox.schema.enums import AnnotationImportState
1820

1921
NDJSON_MIME_TYPE = "application/x-ndjson"
@@ -91,7 +93,7 @@ def wait_until_done(self,
9193
"""
9294
pbar = tqdm(total=100,
9395
bar_format="{n}% |{bar}| [{elapsed}, {rate_fmt}{postfix}]"
94-
) if show_progress else None
96+
) if show_progress else None
9597
while self.state.value == AnnotationImportState.RUNNING.value:
9698
logger.info(f"Sleeping for {sleep_time_seconds} seconds...")
9799
time.sleep(sleep_time_seconds)
@@ -451,6 +453,13 @@ def create_from_objects(
451453
if not data_str:
452454
raise ValueError('annotations cannot be empty')
453455
data = data_str.encode('utf-8')
456+
457+
has_confidence = LabelsConfidencePresenceChecker.check(predictions)
458+
if has_confidence:
459+
logger.warning("""
460+
Confidence scores are not supported in MAL Prediction Import.
461+
Corresponding confidence score values will be ingored.
462+
""")
454463
return cls._create_mal_import_from_bytes(client, project_id, name, data,
455464
len(data))
456465

@@ -603,6 +612,13 @@ def create_from_objects(cls, client: "labelbox.Client", project_id: str,
603612
if not data_str:
604613
raise ValueError('labels cannot be empty')
605614
data = data_str.encode('utf-8')
615+
616+
has_confidence = LabelsConfidencePresenceChecker.check(labels)
617+
if has_confidence:
618+
logger.warning("""
619+
Confidence scores are not supported in Label Import.
620+
Corresponding confidence score values will be ignored.
621+
""")
606622
return cls._create_label_import_from_bytes(client, project_id, name,
607623
data, len(data))
608624

@@ -661,7 +677,8 @@ def from_name(cls,
661677
}
662678
response = client.execute(query_str, params)
663679
if response is None:
664-
raise labelbox.exceptions.ResourceNotFoundError(LabelImport, params)
680+
raise labelbox.exceptions.ResourceNotFoundError(
681+
LabelImport, params)
665682
response = response["labelImport"]
666683
if as_json:
667684
return response
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Any, Dict, List, Union
2+
3+
from labelbox.data.annotation_types.annotation import (
4+
ClassificationAnnotation, ObjectAnnotation, VideoClassificationAnnotation,
5+
VideoObjectAnnotation)
6+
from labelbox.data.annotation_types.classification.classification import (
7+
Checklist, ClassificationAnswer, Dropdown, Radio, Text)
8+
from labelbox.data.annotation_types.label import Label
9+
from labelbox.data.annotation_types.metrics.confusion_matrix import \
10+
ConfusionMatrixMetric
11+
from labelbox.data.annotation_types.metrics.scalar import ScalarMetric
12+
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
13+
14+
15+
class LabelsConfidencePresenceChecker:
16+
"""
17+
Checks if a given list of labels contains at least one confidence score
18+
"""
19+
@classmethod
20+
def check(cls, raw_labels: List[Dict[str, Any]]):
21+
label_list = NDJsonConverter.deserialize(raw_labels).as_list()
22+
return any([cls._check_label(x) for x in label_list])
23+
24+
@classmethod
25+
def _check_label(cls, label: Label):
26+
return any([cls._check_annotation(x) for x in label.annotations])
27+
28+
@classmethod
29+
def _check_annotation(cls, annotation: Union[ClassificationAnnotation, ObjectAnnotation,
30+
VideoObjectAnnotation,
31+
VideoClassificationAnnotation, ScalarMetric,
32+
ConfusionMatrixMetric]):
33+
if annotation.confidence is not None:
34+
return True
35+
if annotation.classifications:
36+
return any([cls._check_classification(x) for x in annotation.classifications])
37+
return False
38+
39+
@classmethod
40+
def _check_classification(cls, classification: ClassificationAnnotation) -> bool:
41+
if isinstance(classification.value, (Checklist, Dropdown)):
42+
return any(cls._check_classification_answer(x) for x in classification.value.answer)
43+
if isinstance(classification.value, Radio):
44+
return cls._check_classification_answer(classification.value.answer)
45+
if isinstance(classification.value, Text):
46+
return False
47+
raise Exception(
48+
f"Unexpected classification value type {type(classification.value)}")
49+
50+
@classmethod
51+
def _check_classification_answer(cls, answer: ClassificationAnswer) -> bool:
52+
return answer.confidence is not None

tests/unit/test_label_import.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import uuid
2+
from unittest.mock import MagicMock, patch
3+
4+
from labelbox.schema.annotation_import import LabelImport, logger
5+
6+
7+
def test_should_warn_user_about_unsupported_confidence():
8+
"""this test should check running state only to validate running, not completed"""
9+
id = str(uuid.uuid4())
10+
11+
labels = [{
12+
"uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
13+
"schemaId": "ckrazcueb16og0z6609jj7y3y",
14+
"dataRow": {
15+
"id": "ckrazctum0z8a0ybc0b0o0g0v"
16+
},
17+
"confidence": 0.851,
18+
"bbox": {
19+
"top": 1352,
20+
"left": 2275,
21+
"height": 350,
22+
"width": 139
23+
}
24+
}, ]
25+
with patch.object(LabelImport, '_create_label_import_from_bytes'):
26+
with patch.object(logger, 'warning') as warning_mock:
27+
LabelImport.create_from_objects(
28+
client=MagicMock(),
29+
project_id=id,
30+
name=id,
31+
labels=labels)
32+
warning_mock.assert_called_once()
33+
"Confidence scores are not supported in Label Import" in warning_mock.call_args_list[
34+
0].args[0]

tests/unit/test_mal_import.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import uuid
2+
from unittest.mock import MagicMock, patch
3+
4+
from labelbox.schema.annotation_import import MALPredictionImport, logger
5+
6+
7+
def test_should_warn_user_about_unsupported_confidence():
8+
"""this test should check running state only to validate running, not completed"""
9+
id = str(uuid.uuid4())
10+
11+
labels = [{
12+
"bbox": {
13+
"height": 428,
14+
"left": 2089,
15+
"top": 1251,
16+
"width": 158
17+
},
18+
"classifications": [
19+
{
20+
"answer": [
21+
{
22+
"schemaId": "ckrb1sfl8099e0y919v260awv",
23+
"confidence": 0.894
24+
}
25+
],
26+
"schemaId": "ckrb1sfkn099c0y910wbo0p1a"
27+
}
28+
],
29+
"dataRow": {
30+
"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"
31+
},
32+
"schemaId": "ckrb1sfjx099a0y914hl319ie",
33+
"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"
34+
}, ]
35+
with patch.object(MALPredictionImport, '_create_mal_import_from_bytes'):
36+
with patch.object(logger, 'warning') as warning_mock:
37+
MALPredictionImport.create_from_objects(
38+
client=MagicMock(),
39+
project_id=id,
40+
name=id,
41+
predictions=labels)
42+
warning_mock.assert_called_once()
43+
"Confidence scores are not supported in MAL Prediction Import" in warning_mock.call_args_list[
44+
0].args[0]

0 commit comments

Comments
 (0)