Skip to content

Commit f9bcd45

Browse files
authored
[AL-6008] Add confidence to free text Python annotation class (#1131)
2 parents 1e28193 + b284d0d commit f9bcd45

File tree

6 files changed

+307
-34
lines changed

6 files changed

+307
-34
lines changed

labelbox/data/serialization/ndjson/classification.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,15 @@ class NDTextSubclass(NDAnswer):
6464
answer: str
6565

6666
def to_common(self) -> Text:
67-
return Text(answer=self.answer)
67+
return Text(answer=self.answer, confidence=self.confidence)
6868

6969
@classmethod
7070
def from_common(cls, text: Text, name: str,
7171
feature_schema_id: Cuid) -> "NDTextSubclass":
72-
return cls(answer=text.answer, name=name, schema_id=feature_schema_id)
72+
return cls(answer=text.answer,
73+
name=name,
74+
schema_id=feature_schema_id,
75+
confidence=text.confidence)
7376

7477

7578
class NDChecklistSubclass(NDAnswer):
@@ -161,7 +164,7 @@ def from_common(cls,
161164
schema_id=feature_schema_id,
162165
uuid=uuid,
163166
message_id=message_id,
164-
confidence=confidence,
167+
confidence=text.confidence,
165168
)
166169

167170

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
2+
from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text
3+
from labelbox.data.annotation_types.data.text import TextData
4+
from labelbox.data.annotation_types.label import Label
5+
6+
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
7+
8+
9+
def test_serialization():
10+
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
11+
data=TextData(
12+
uid="bkj7z2q0b0000jx6x0q2q7q0d",
13+
text="This is a test",
14+
),
15+
annotations=[
16+
ClassificationAnnotation(name="free_text_annotation",
17+
value=Text(confidence=0.5,
18+
answer="text_answer"))
19+
])
20+
21+
serialized = NDJsonConverter.serialize([label])
22+
res = next(serialized)
23+
24+
assert res['confidence'] == 0.5
25+
assert res['name'] == "free_text_annotation"
26+
assert res['answer'] == "text_answer"
27+
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
28+
29+
deserialized = NDJsonConverter.deserialize([res])
30+
res = next(deserialized)
31+
32+
annotation = res.annotations[0]
33+
34+
annotation_value = annotation.value
35+
assert type(annotation_value) is Text
36+
assert annotation_value.answer == "text_answer"
37+
assert annotation_value.confidence == 0.5
38+
39+
40+
def test_nested_serialization():
41+
label = Label(
42+
uid="ckj7z2q0b0000jx6x0q2q7q0d",
43+
data=TextData(
44+
uid="bkj7z2q0b0000jx6x0q2q7q0d",
45+
text="This is a test",
46+
),
47+
annotations=[
48+
ClassificationAnnotation(
49+
name="nested test",
50+
value=Checklist(answer=[
51+
ClassificationAnswer(
52+
name="first_answer",
53+
confidence=0.9,
54+
classifications=[
55+
ClassificationAnnotation(
56+
name="sub_radio_question",
57+
value=Radio(answer=ClassificationAnswer(
58+
name="first_sub_radio_answer",
59+
confidence=0.8,
60+
classifications=[
61+
ClassificationAnnotation(
62+
name="nested answer",
63+
value=Text(
64+
answer="nested answer",
65+
confidence=0.7,
66+
))
67+
])))
68+
])
69+
]),
70+
)
71+
])
72+
73+
serialized = NDJsonConverter.serialize([label])
74+
res = next(serialized)
75+
76+
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
77+
answer = res['answer'][0]
78+
assert answer['confidence'] == 0.9
79+
assert answer['name'] == "first_answer"
80+
classification = answer['classifications'][0]
81+
nested_classification_answer = classification['answer']
82+
assert nested_classification_answer['confidence'] == 0.8
83+
assert nested_classification_answer['name'] == "first_sub_radio_answer"
84+
sub_classification = nested_classification_answer['classifications'][0]
85+
assert sub_classification['name'] == "nested answer"
86+
assert sub_classification['answer'] == "nested answer"
87+
assert sub_classification['confidence'] == 0.7
88+
89+
deserialized = NDJsonConverter.deserialize([res])
90+
res = next(deserialized)
91+
annotation = res.annotations[0]
92+
answer = annotation.value.answer[0]
93+
assert answer.confidence == 0.9
94+
assert answer.name == "first_answer"
95+
96+
classification_answer = answer.classifications[0].value.answer
97+
assert classification_answer.confidence == 0.8
98+
assert classification_answer.name == "first_sub_radio_answer"
99+
100+
sub_classification_answer = classification_answer.classifications[0].value
101+
assert type(sub_classification_answer) is Text
102+
assert sub_classification_answer.answer == "nested answer"
103+
assert sub_classification_answer.confidence == 0.7

tests/data/serialization/ndjson/test_text.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,15 @@ def test_serialization():
2121

2222
serialized = NDJsonConverter.serialize([label])
2323
res = next(serialized)
24-
assert res['confidence'] == 0.5
24+
assert 'confidence' not in res # because confidence needs to be set on the annotation itself
2525
assert res['name'] == "radio_question_geo"
2626
assert res['answer'] == "first_radio_answer"
2727
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
2828

2929
deserialized = NDJsonConverter.deserialize([res])
3030
res = next(deserialized)
3131
annotation = res.annotations[0]
32-
assert annotation.confidence == 0.5
3332

3433
annotation_value = annotation.value
3534
assert type(annotation_value) is Text
3635
assert annotation_value.answer == "first_radio_answer"
37-
38-
serialized = NDJsonConverter.serialize([label])
39-
res = next(serialized)
40-
assert res['confidence'] == 0.5
41-
assert res['name'] == "radio_question_geo"
42-
assert res['answer'] == "first_radio_answer"
43-
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
44-
45-
deserialized = NDJsonConverter.deserialize([res])
46-
res = next(deserialized)
47-
annotation = res.annotations[0]
48-
assert annotation.confidence == 0.5
49-
50-
annotation_value = annotation.value
51-
assert type(annotation_value) is Text
52-
assert annotation_value.answer == "first_radio_answer"

0 commit comments

Comments
 (0)