Skip to content

Commit 6b9e8b7

Browse files
author
Val Brodsky
committed
Add ConfidenceMixin to Classification annotation and it's *values*
1 parent bbff2f6 commit 6b9e8b7

File tree

4 files changed

+54
-31
lines changed

4 files changed

+54
-31
lines changed

labelbox/data/annotation_types/annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class BaseAnnotation(FeatureSchema, abc.ABC):
1616
extra: Dict[str, Any] = {}
1717

1818

19-
class ClassificationAnnotation(BaseAnnotation):
19+
class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin):
2020
"""Classification annotations (non localized)
2121
2222
>>> ClassificationAnnotation(

labelbox/data/annotation_types/classification/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
# TODO: Replace when pydantic adds support for unions that don't coerce types
16-
class _TempName(BaseModel):
16+
class _TempName(ConfidenceMixin, BaseModel):
1717
name: str
1818

1919
def dict(self, *args, **kwargs):
@@ -43,7 +43,7 @@ def dict(self, *args, **kwargs) -> Dict[str, str]:
4343
return res
4444

4545

46-
class Radio(BaseModel):
46+
class Radio(ConfidenceMixin, BaseModel):
4747
""" A classification with only one selected option allowed
4848
4949
>>> Radio(answer = ClassificationAnswer(name = "dog"))
@@ -62,7 +62,7 @@ class Checklist(_TempName):
6262
answer: List[ClassificationAnswer]
6363

6464

65-
class Text(BaseModel):
65+
class Text(ConfidenceMixin, BaseModel):
6666
""" Free form text
6767
6868
>>> Text(answer = "some text answer")

labelbox/data/serialization/ndjson/classification.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,33 @@ def from_common(cls, radio: Radio, name: str,
120120
class NDText(NDAnnotation, NDTextSubclass):
121121

122122
@classmethod
123-
def from_common(cls, text: Text, name: str, feature_schema_id: Cuid,
124-
extra: Dict[str, Any], data: Union[TextData,
125-
ImageData]) -> "NDText":
123+
def from_common(cls,
124+
text: Text,
125+
name: str,
126+
feature_schema_id: Cuid,
127+
extra: Dict[str, Any],
128+
data: Union[TextData, ImageData],
129+
confidence: Optional[float] = None) -> "NDText":
126130
return cls(
127131
answer=text.answer,
128132
data_row=DataRow(id=data.uid, global_key=data.global_key),
129133
name=name,
130134
schema_id=feature_schema_id,
131135
uuid=extra.get('uuid'),
136+
confidence=confidence,
132137
)
133138

134139

135140
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
136141

137142
@classmethod
138-
def from_common(
139-
cls, checklist: Checklist, name: str, feature_schema_id: Cuid,
140-
extra: Dict[str, Any], data: Union[VideoData, TextData,
141-
ImageData]) -> "NDChecklist":
143+
def from_common(cls,
144+
checklist: Checklist,
145+
name: str,
146+
feature_schema_id: Cuid,
147+
extra: Dict[str, Any],
148+
data: Union[VideoData, TextData, ImageData],
149+
confidence: Optional[float] = None) -> "NDChecklist":
142150
return cls(answer=[
143151
NDFeature(name=answer.name,
144152
schema_id=answer.feature_schema_id,
@@ -149,23 +157,29 @@ def from_common(
149157
name=name,
150158
schema_id=feature_schema_id,
151159
uuid=extra.get('uuid'),
152-
frames=extra.get('frames'))
160+
frames=extra.get('frames'),
161+
confidence=confidence)
153162

154163

155164
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):
156165

157166
@classmethod
158-
def from_common(cls, radio: Radio, name: str, feature_schema_id: Cuid,
159-
extra: Dict[str, Any], data: Union[VideoData, TextData,
160-
ImageData]) -> "NDRadio":
167+
def from_common(cls,
168+
radio: Radio,
169+
name: str,
170+
feature_schema_id: Cuid,
171+
extra: Dict[str, Any],
172+
data: Union[VideoData, TextData, ImageData],
173+
confidence: Optional[float] = None) -> "NDRadio":
161174
return cls(answer=NDFeature(name=radio.answer.name,
162175
schema_id=radio.answer.feature_schema_id,
163176
confidence=radio.answer.confidence),
164177
data_row=DataRow(id=data.uid, global_key=data.global_key),
165178
name=name,
166179
schema_id=feature_schema_id,
167180
uuid=extra.get('uuid'),
168-
frames=extra.get('frames'))
181+
frames=extra.get('frames'),
182+
confidence=confidence)
169183

170184

171185
class NDSubclassification:
@@ -235,7 +249,8 @@ def from_common(
235249
)
236250
return classify_obj.from_common(annotation.value, annotation.name,
237251
annotation.feature_schema_id,
238-
annotation.extra, data)
252+
annotation.extra, data,
253+
annotation.confidence)
239254

240255
@staticmethod
241256
def lookup_classification(
Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
11
import json
2+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
3+
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio, Text
4+
from labelbox.data.annotation_types.data.text import TextData
5+
from labelbox.data.annotation_types.label import Label
26

37
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
48

59

6-
def test_text():
7-
with open('tests/data/assets/ndjson/text_import.json', 'r') as file:
8-
data = json.load(file)
9-
res = list(NDJsonConverter.deserialize(data))
10-
res = list(NDJsonConverter.serialize(res))
11-
assert res == data
10+
def test_serialization():
11+
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
12+
data=TextData(
13+
uid="bkj7z2q0b0000jx6x0q2q7q0d",
14+
text="This is a test",
15+
),
16+
annotations=[
17+
ClassificationAnnotation(
18+
name="radio_question_geo",
19+
confidence=0.5,
20+
value=Text(answer="first_radio_answer"))
21+
])
1222

13-
14-
def test_text_name_only():
15-
with open('tests/data/assets/ndjson/text_import_name_only.json',
16-
'r') as file:
17-
data = json.load(file)
18-
res = list(NDJsonConverter.deserialize(data))
19-
res = list(NDJsonConverter.serialize(res))
20-
assert res == data
23+
serialized = NDJsonConverter.serialize([label])
24+
res = next(serialized)
25+
assert res['confidence'] == 0.5
26+
assert res['name'] == "radio_question_geo"
27+
assert res['answer'] == "first_radio_answer"
28+
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"

0 commit comments

Comments
 (0)