Skip to content

Commit 1b45ef7

Browse files
Merge pull request #761 from Labelbox/sdubinin/al-3578
[AL-3578] Confidence support in SDK
2 parents bf824b8 + f05d09e commit 1b45ef7

26 files changed

+2152
-319
lines changed

labelbox/data/annotation_types/annotation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import abc
22
from typing import Any, Dict, List, Optional, Union
33

4+
from labelbox.data.mixins import ConfidenceNotSupportedMixin, ConfidenceMixin
5+
46
from .classification import Checklist, Dropdown, Radio, Text
57
from .feature import FeatureSchema
68
from .geometry import Geometry, Rectangle, Point
@@ -31,7 +33,7 @@ class ClassificationAnnotation(BaseAnnotation):
3133
value: Union[Text, Checklist, Radio, Dropdown]
3234

3335

34-
class ObjectAnnotation(BaseAnnotation):
36+
class ObjectAnnotation(BaseAnnotation, ConfidenceMixin):
3537
"""Generic localized annotation (non classifications)
3638
3739
>>> ObjectAnnotation(
@@ -53,7 +55,7 @@ class ObjectAnnotation(BaseAnnotation):
5355
classifications: List[ClassificationAnnotation] = []
5456

5557

56-
class VideoObjectAnnotation(ObjectAnnotation):
58+
class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin):
5759
"""Video object annotation
5860
5961
>>> VideoObjectAnnotation(
@@ -76,6 +78,7 @@ class VideoObjectAnnotation(ObjectAnnotation):
7678
classifications (List[ClassificationAnnotation]) = []
7779
extra (Dict[str, Any])
7880
"""
81+
7982
frame: int
8083
keyframe: bool
8184
segment_index: Optional[int] = None

labelbox/data/annotation_types/classification/classification.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Any, Dict, List, Union, Optional
22
import warnings
33

4+
from labelbox.data.mixins import ConfidenceMixin
5+
46
try:
57
from typing import Literal
68
except:
@@ -20,7 +22,7 @@ def dict(self, *args, **kwargs):
2022
return res
2123

2224

23-
class ClassificationAnswer(FeatureSchema):
25+
class ClassificationAnswer(FeatureSchema, ConfidenceMixin):
2426
"""
2527
- Represents a classification option.
2628
- Because it inherits from FeatureSchema

labelbox/data/mixins.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from typing import Optional
2+
3+
from pydantic import BaseModel, validator
4+
5+
from labelbox.exceptions import ConfidenceNotSupportedException
6+
7+
8+
class ConfidenceMixin(BaseModel):
9+
confidence: Optional[float] = None
10+
11+
@validator('confidence')
12+
def confidence_valid_float(cls, value):
13+
if value is None:
14+
return value
15+
if not isinstance(value, (int, float)) or not 0 <= value <= 1:
16+
raise ValueError('must be a number within [0,1] range')
17+
return value
18+
19+
def dict(self, *args, **kwargs):
20+
res = super().dict(*args, **kwargs)
21+
if 'confidence' in res and res['confidence'] is None:
22+
res.pop('confidence')
23+
return res
24+
25+
26+
class ConfidenceNotSupportedMixin:
27+
28+
def __new__(cls, *args, **kwargs):
29+
if 'confidence' in kwargs:
30+
raise ConfidenceNotSupportedException(
31+
'Confidence is not supported for this annotaiton type yet')
32+
return super().__new__(cls)

labelbox/data/serialization/ndjson/classification.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Dict, List, Union, Optional
22

33
from pydantic import BaseModel, Field, root_validator
4+
from labelbox.data.mixins import ConfidenceMixin
45

56
from labelbox.utils import camel_case
67
from ...annotation_types.annotation import ClassificationAnnotation, VideoClassificationAnnotation
@@ -10,7 +11,7 @@
1011
from .base import NDAnnotation
1112

1213

13-
class NDFeature(BaseModel):
14+
class NDFeature(ConfidenceMixin):
1415
name: Optional[str] = None
1516
schema_id: Optional[Cuid] = None
1617

@@ -41,7 +42,7 @@ class FrameLocation(BaseModel):
4142

4243

4344
class VideoSupported(BaseModel):
44-
#Note that frames are only allowed as top level inferences for video
45+
# Note that frames are only allowed as top level inferences for video
4546
frames: Optional[List[FrameLocation]] = None
4647

4748
def dict(self, *args, **kwargs):
@@ -70,15 +71,18 @@ class NDChecklistSubclass(NDFeature):
7071
def to_common(self) -> Checklist:
7172
return Checklist(answer=[
7273
ClassificationAnswer(name=answer.name,
73-
feature_schema_id=answer.schema_id)
74+
feature_schema_id=answer.schema_id,
75+
confidence=answer.confidence)
7476
for answer in self.answer
7577
])
7678

7779
@classmethod
7880
def from_common(cls, checklist: Checklist, name: str,
7981
feature_schema_id: Cuid) -> "NDChecklistSubclass":
8082
return cls(answer=[
81-
NDFeature(name=answer.name, schema_id=answer.feature_schema_id)
83+
NDFeature(name=answer.name,
84+
schema_id=answer.feature_schema_id,
85+
confidence=answer.confidence)
8286
for answer in checklist.answer
8387
],
8488
name=name,
@@ -95,19 +99,22 @@ class NDRadioSubclass(NDFeature):
9599
answer: NDFeature
96100

97101
def to_common(self) -> Radio:
98-
return Radio(answer=ClassificationAnswer(
99-
name=self.answer.name, feature_schema_id=self.answer.schema_id))
102+
return Radio(
103+
answer=ClassificationAnswer(name=self.answer.name,
104+
feature_schema_id=self.answer.schema_id,
105+
confidence=self.answer.confidence))
100106

101107
@classmethod
102108
def from_common(cls, radio: Radio, name: str,
103109
feature_schema_id: Cuid) -> "NDRadioSubclass":
104110
return cls(answer=NDFeature(name=radio.answer.name,
105-
schema_id=radio.answer.feature_schema_id),
111+
schema_id=radio.answer.feature_schema_id,
112+
confidence=radio.answer.confidence),
106113
name=name,
107114
schema_id=feature_schema_id)
108115

109116

110-
### ====== End of subclasses
117+
# ====== End of subclasses
111118

112119

113120
class NDText(NDAnnotation, NDTextSubclass):
@@ -133,7 +140,9 @@ def from_common(
133140
extra: Dict[str, Any], data: Union[VideoData, TextData,
134141
ImageData]) -> "NDChecklist":
135142
return cls(answer=[
136-
NDFeature(name=answer.name, schema_id=answer.feature_schema_id)
143+
NDFeature(name=answer.name,
144+
schema_id=answer.feature_schema_id,
145+
confidence=answer.confidence)
137146
for answer in checklist.answer
138147
],
139148
data_row={'id': data.uid},
@@ -150,7 +159,8 @@ def from_common(cls, radio: Radio, name: str, feature_schema_id: Cuid,
150159
extra: Dict[str, Any], data: Union[VideoData, TextData,
151160
ImageData]) -> "NDRadio":
152161
return cls(answer=NDFeature(name=radio.answer.name,
153-
schema_id=radio.answer.feature_schema_id),
162+
schema_id=radio.answer.feature_schema_id,
163+
confidence=radio.answer.confidence),
154164
data_row={'id': data.uid},
155165
name=name,
156166
schema_id=feature_schema_id,
@@ -241,6 +251,11 @@ def lookup_classification(
241251
}.get(type(annotation.value))
242252

243253

244-
NDSubclassificationType = Union[NDRadioSubclass, NDChecklistSubclass,
254+
# Make sure to keep NDChecklistSubclass prior to NDRadioSubclass in the list,
255+
# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
256+
NDSubclassificationType = Union[NDChecklistSubclass, NDRadioSubclass,
245257
NDTextSubclass]
246-
NDClassificationType = Union[NDRadio, NDChecklist, NDText]
258+
259+
# Make sure to keep NDChecklist prior to NDRadio in the list,
260+
# otherwise list of answers gets parsed by NDRadio whereas NDChecklist must be used
261+
NDClassificationType = Union[NDChecklist, NDRadio, NDText]

0 commit comments

Comments
 (0)