Skip to content

Commit 64badb2

Browse files
authored
[AL-7738] Support custom metrics for predictions (#1399)
2 parents f50b8b7 + 868efc8 commit 64badb2

28 files changed

+663
-171
lines changed

labelbox/data/annotation_types/annotation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
from labelbox.data.annotation_types.base_annotation import BaseAnnotation
44
from labelbox.data.annotation_types.geometry.geometry import Geometry
55

6-
from labelbox.data.mixins import ConfidenceMixin
6+
from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin
77

88
from labelbox.data.annotation_types.classification.classification import ClassificationAnnotation
99
from .ner import DocumentEntity, TextEntity, ConversationEntity
1010

1111

12-
class ObjectAnnotation(BaseAnnotation, ConfidenceMixin):
12+
class ObjectAnnotation(BaseAnnotation, ConfidenceMixin, CustomMetricsMixin):
1313
"""Generic localized annotation (non classifications)
1414
1515
>>> ObjectAnnotation(
@@ -27,5 +27,6 @@ class ObjectAnnotation(BaseAnnotation, ConfidenceMixin):
2727
classifications (Optional[List[ClassificationAnnotation]]): Optional sub classification of the annotation
2828
extra (Dict[str, Any])
2929
"""
30+
3031
value: Union[TextEntity, ConversationEntity, DocumentEntity, Geometry]
3132
classifications: List[ClassificationAnnotation] = []

labelbox/data/annotation_types/classification/classification.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
from labelbox.data.annotation_types.base_annotation import BaseAnnotation
44

5-
from labelbox.data.mixins import ConfidenceMixin
5+
from labelbox.data.mixins import ConfidenceMixin, CustomMetricsMixin
66

77
try:
88
from typing import Literal
@@ -23,7 +23,7 @@ def dict(self, *args, **kwargs):
2323
return res
2424

2525

26-
class ClassificationAnswer(FeatureSchema, ConfidenceMixin):
26+
class ClassificationAnswer(FeatureSchema, ConfidenceMixin, CustomMetricsMixin):
2727
"""
2828
- Represents a classification option.
2929
- Because it inherits from FeatureSchema
@@ -47,7 +47,7 @@ def dict(self, *args, **kwargs) -> Dict[str, str]:
4747
return res
4848

4949

50-
class Radio(ConfidenceMixin, BaseModel):
50+
class Radio(ConfidenceMixin, CustomMetricsMixin, BaseModel):
5151
""" A classification with only one selected option allowed
5252
5353
>>> Radio(answer = ClassificationAnswer(name = "dog"))
@@ -66,7 +66,7 @@ class Checklist(_TempName):
6666
answer: List[ClassificationAnswer]
6767

6868

69-
class Text(ConfidenceMixin, BaseModel):
69+
class Text(ConfidenceMixin, CustomMetricsMixin, BaseModel):
7070
""" Free form text
7171
7272
>>> Text(answer = "some text answer")
@@ -93,7 +93,8 @@ def __init__(self, **data: Any):
9393
"removed in a future release")
9494

9595

96-
class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin):
96+
class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin,
97+
CustomMetricsMixin):
9798
"""Classification annotations (non localized)
9899
99100
>>> ClassificationAnnotation(

labelbox/data/annotation_types/video.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from labelbox.data.annotation_types.annotation import ClassificationAnnotation, ObjectAnnotation
88
from labelbox.data.annotation_types.feature import FeatureSchema
9-
from labelbox.data.mixins import ConfidenceNotSupportedMixin
9+
from labelbox.data.mixins import ConfidenceNotSupportedMixin, CustomMetricsNotSupportedMixin
1010
from labelbox.utils import _CamelCaseMixin, is_valid_uri
1111

1212

@@ -24,7 +24,8 @@ class VideoClassificationAnnotation(ClassificationAnnotation):
2424
segment_index: Optional[int] = None
2525

2626

27-
class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin):
27+
class VideoObjectAnnotation(ObjectAnnotation, ConfidenceNotSupportedMixin,
28+
CustomMetricsNotSupportedMixin):
2829
"""Video object annotation
2930
>>> VideoObjectAnnotation(
3031
>>> keyframe=True,

labelbox/data/mixins.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,73 @@
1-
from typing import Optional
1+
from typing import Optional, List
22

33
from pydantic import BaseModel, validator
44

5-
from labelbox.exceptions import ConfidenceNotSupportedException
5+
from labelbox.exceptions import ConfidenceNotSupportedException, CustomMetricsNotSupportedException
66

77

88
class ConfidenceMixin(BaseModel):
99
confidence: Optional[float] = None
1010

11-
@validator('confidence')
11+
@validator("confidence")
1212
def confidence_valid_float(cls, value):
1313
if value is None:
1414
return value
1515
if not isinstance(value, (int, float)) or not 0 <= value <= 1:
16-
raise ValueError('must be a number within [0,1] range')
16+
raise ValueError("must be a number within [0,1] range")
1717
return value
1818

1919
def dict(self, *args, **kwargs):
2020
res = super().dict(*args, **kwargs)
21-
if 'confidence' in res and res['confidence'] is None:
22-
res.pop('confidence')
21+
if "confidence" in res and res["confidence"] is None:
22+
res.pop("confidence")
2323
return res
2424

2525

2626
class ConfidenceNotSupportedMixin:
2727

2828
def __new__(cls, *args, **kwargs):
29-
if 'confidence' in kwargs:
29+
if "confidence" in kwargs:
3030
raise ConfidenceNotSupportedException(
31-
'Confidence is not supported for this annotaiton type yet')
31+
"Confidence is not supported for this annotation type yet")
32+
return super().__new__(cls)
33+
34+
35+
class CustomMetric(BaseModel):
36+
name: str
37+
value: float
38+
39+
@validator("name")
40+
def confidence_valid_float(cls, value):
41+
if not isinstance(value, str):
42+
raise ValueError("Name must be a string")
43+
return value
44+
45+
@validator("value")
46+
def value_valid_float(cls, value):
47+
if not isinstance(value, (int, float)):
48+
raise ValueError("Value must be a number")
49+
return value
50+
51+
52+
class CustomMetricsMixin(BaseModel):
53+
custom_metrics: Optional[List[CustomMetric]] = None
54+
55+
def dict(self, *args, **kwargs):
56+
res = super().dict(*args, **kwargs)
57+
58+
if "customMetrics" in res and res["customMetrics"] is None:
59+
res.pop("customMetrics")
60+
61+
if "custom_metrics" in res and res["custom_metrics"] is None:
62+
res.pop("custom_metrics")
63+
64+
return res
65+
66+
67+
class CustomMetricsNotSupportedMixin:
68+
69+
def __new__(cls, *args, **kwargs):
70+
if "custom_metrics" in kwargs:
71+
raise CustomMetricsNotSupportedException(
72+
"Custom metrics is not supported for this annotation type yet")
3273
return super().__new__(cls)

labelbox/data/serialization/ndjson/classification.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Dict, List, Union, Optional
22

33
from pydantic import BaseModel, Field, root_validator
4-
from labelbox.data.mixins import ConfidenceMixin
4+
from labelbox.data.mixins import ConfidenceMixin, CustomMetric, CustomMetricsMixin
55
from labelbox.data.serialization.ndjson.base import DataRow, NDAnnotation
66

77
from labelbox.utils import camel_case
@@ -12,7 +12,7 @@
1212
from ...annotation_types.data import TextData, VideoData, ImageData
1313

1414

15-
class NDAnswer(ConfidenceMixin):
15+
class NDAnswer(ConfidenceMixin, CustomMetricsMixin):
1616
name: Optional[str] = None
1717
schema_id: Optional[Cuid] = None
1818
classifications: Optional[List['NDSubclassificationType']] = []
@@ -64,15 +64,20 @@ class NDTextSubclass(NDAnswer):
6464
answer: str
6565

6666
def to_common(self) -> Text:
67-
return Text(answer=self.answer, confidence=self.confidence)
67+
return Text(answer=self.answer,
68+
confidence=self.confidence,
69+
custom_metrics=self.custom_metrics)
6870

6971
@classmethod
7072
def from_common(cls, text: Text, name: str,
7173
feature_schema_id: Cuid) -> "NDTextSubclass":
72-
return cls(answer=text.answer,
73-
name=name,
74-
schema_id=feature_schema_id,
75-
confidence=text.confidence)
74+
return cls(
75+
answer=text.answer,
76+
name=name,
77+
schema_id=feature_schema_id,
78+
confidence=text.confidence,
79+
custom_metrics=text.custom_metrics,
80+
)
7681

7782

7883
class NDChecklistSubclass(NDAnswer):
@@ -87,7 +92,8 @@ def to_common(self) -> Checklist:
8792
classifications=[
8893
NDSubclassification.to_common(annot)
8994
for annot in answer.classifications
90-
])
95+
],
96+
custom_metrics=answer.custom_metrics)
9197
for answer in self.answer
9298
])
9399

@@ -101,7 +107,8 @@ def from_common(cls, checklist: Checklist, name: str,
101107
classifications=[
102108
NDSubclassification.from_common(annot)
103109
for annot in answer.classifications
104-
])
110+
],
111+
custom_metrics=answer.custom_metrics)
105112
for answer in checklist.answer
106113
],
107114
name=name,
@@ -126,7 +133,7 @@ def to_common(self) -> Radio:
126133
NDSubclassification.to_common(annot)
127134
for annot in self.answer.classifications
128135
],
129-
))
136+
custom_metrics=self.answer.custom_metrics))
130137

131138
@classmethod
132139
def from_common(cls, radio: Radio, name: str,
@@ -137,7 +144,8 @@ def from_common(cls, radio: Radio, name: str,
137144
classifications=[
138145
NDSubclassification.from_common(annot)
139146
for annot in radio.answer.classifications
140-
]),
147+
],
148+
custom_metrics=radio.answer.custom_metrics),
141149
name=name,
142150
schema_id=feature_schema_id)
143151

@@ -165,21 +173,25 @@ def from_common(cls,
165173
uuid=uuid,
166174
message_id=message_id,
167175
confidence=text.confidence,
176+
custom_metrics=text.custom_metrics,
168177
)
169178

170179

171180
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
172181

173182
@classmethod
174-
def from_common(cls,
175-
uuid: str,
176-
checklist: Checklist,
177-
name: str,
178-
feature_schema_id: Cuid,
179-
extra: Dict[str, Any],
180-
data: Union[VideoData, TextData, ImageData],
181-
message_id: str,
182-
confidence: Optional[float] = None) -> "NDChecklist":
183+
def from_common(
184+
cls,
185+
uuid: str,
186+
checklist: Checklist,
187+
name: str,
188+
feature_schema_id: Cuid,
189+
extra: Dict[str, Any],
190+
data: Union[VideoData, TextData, ImageData],
191+
message_id: str,
192+
confidence: Optional[float] = None,
193+
custom_metrics: Optional[List[CustomMetric]] = None
194+
) -> "NDChecklist":
183195

184196
return cls(answer=[
185197
NDAnswer(name=answer.name,
@@ -188,7 +200,8 @@ def from_common(cls,
188200
classifications=[
189201
NDSubclassification.from_common(annot)
190202
for annot in answer.classifications
191-
])
203+
],
204+
custom_metrics=answer.custom_metrics)
192205
for answer in checklist.answer
193206
],
194207
data_row=DataRow(id=data.uid, global_key=data.global_key),
@@ -220,7 +233,8 @@ def from_common(
220233
classifications=[
221234
NDSubclassification.from_common(annot)
222235
for annot in radio.answer.classifications
223-
]),
236+
],
237+
custom_metrics=radio.answer.custom_metrics),
224238
data_row=DataRow(id=data.uid, global_key=data.global_key),
225239
name=name,
226240
schema_id=feature_schema_id,

0 commit comments

Comments
 (0)