Skip to content

Commit 6a0e9e3

Browse files
committed
AL-3578: Simplify confidence presence checker
1 parent cad8224 commit 6a0e9e3

File tree

1 file changed

+17
-54
lines changed

1 file changed

+17
-54
lines changed
Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,4 @@
1-
from typing import Any, Dict, List, Union
2-
3-
from labelbox.data.annotation_types.annotation import (
4-
ClassificationAnnotation, ObjectAnnotation, VideoClassificationAnnotation,
5-
VideoObjectAnnotation)
6-
from labelbox.data.annotation_types.classification.classification import (
7-
Checklist, ClassificationAnswer, Dropdown, Radio, Text)
8-
from labelbox.data.annotation_types.label import Label
9-
from labelbox.data.annotation_types.metrics.confusion_matrix import \
10-
ConfusionMatrixMetric
11-
from labelbox.data.annotation_types.metrics.scalar import ScalarMetric
12-
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
1+
from typing import Any, Dict, List, Set
132

143

154
class LabelsConfidencePresenceChecker:
@@ -19,49 +8,23 @@ class LabelsConfidencePresenceChecker:
198

209
@classmethod
2110
def check(cls, raw_labels: List[Dict[str, Any]]):
22-
label_list = NDJsonConverter.deserialize(raw_labels).as_list()
23-
return any([cls._check_label(x) for x in label_list])
24-
25-
@classmethod
26-
def _check_label(cls, label: Label):
27-
return any([cls._check_annotation(x) for x in label.annotations])
28-
29-
@classmethod
30-
def _check_annotation(cls, annotation: Union[ClassificationAnnotation,
31-
ObjectAnnotation,
32-
VideoObjectAnnotation,
33-
VideoClassificationAnnotation,
34-
ScalarMetric,
35-
ConfusionMatrixMetric]):
36-
37-
confidence: Union[float,
38-
None] = getattr(annotation, 'confidence') if hasattr(
39-
annotation, 'confidence') else None
40-
if confidence is not None:
41-
return True
42-
43-
classifications: Union[List[ClassificationAnnotation], None] = getattr(
44-
annotation, 'classifications') if hasattr(
45-
annotation, 'classifications') else None
46-
if classifications:
47-
return any([cls._check_classification(x) for x in classifications])
48-
return False
11+
keys = set([])
12+
cls._collect_keys_from_list(raw_labels, keys)
13+
return len(keys.intersection(set(["confidence"]))) == 1
4914

5015
@classmethod
51-
def _check_classification(cls,
52-
classification: ClassificationAnnotation) -> bool:
53-
if isinstance(classification.value, (Checklist, Dropdown)):
54-
return any(
55-
cls._check_classification_answer(x)
56-
for x in classification.value.answer)
57-
if isinstance(classification.value, Radio):
58-
return cls._check_classification_answer(classification.value.answer)
59-
if isinstance(classification.value, Text):
60-
return False
61-
raise Exception(
62-
f"Unexpected classification value type {type(classification.value)}"
63-
)
16+
def _collect_keys_from_list(cls, objects: List[Dict[str, Any]], keys: Set):
17+
for obj in objects:
18+
if isinstance(obj, (list, tuple)):
19+
cls._collect_keys_from_list(obj, keys)
20+
elif isinstance(obj, dict):
21+
cls._collect_keys_from_object(obj, keys)
6422

6523
@classmethod
66-
def _check_classification_answer(cls, answer: ClassificationAnswer) -> bool:
67-
return answer.confidence is not None
24+
def _collect_keys_from_object(cls, object: Dict[str, Any], keys: Set):
25+
for key in object:
26+
keys.add(key)
27+
if isinstance(object[key], dict):
28+
cls._collect_keys_from_object(object[key], keys)
29+
if isinstance(object[key], (list, tuple)):
30+
cls._collect_keys_from_list(object[key], keys)

0 commit comments

Comments
 (0)