Skip to content

Commit 2c70cd6

Browse files
committed
AL-3578: Added cases for text classifications
1 parent 3eebb61 commit 2c70cd6

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

labelbox/data/annotation_types/classification/classification.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Any, Dict, List, Union, Optional
22
import warnings
33

4+
from labelbox.data.mixins import ConfidenceMixin
5+
46
try:
57
from typing import Literal
68
except:
@@ -20,7 +22,7 @@ def dict(self, *args, **kwargs):
2022
return res
2123

2224

23-
class ClassificationAnswer(FeatureSchema):
25+
class ClassificationAnswer(FeatureSchema, ConfidenceMixin):
2426
"""
2527
- Represents a classification option.
2628
- Because it inherits from FeatureSchema
@@ -33,7 +35,6 @@ class ClassificationAnswer(FeatureSchema):
3335
"""
3436
extra: Dict[str, Any] = {}
3537
keyframe: Optional[bool] = None
36-
confidence: Optional[float] = None
3738

3839
def dict(self, *args, **kwargs) -> Dict[str, str]:
3940
res = super().dict(*args, **kwargs)

tests/data/annotation_types/classification/test_classification.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ def test_classification_answer():
1212

1313
feature_schema_id = "schema_id"
1414
name = "my_feature"
15-
answer = ClassificationAnswer(name=name)
15+
confidence = 0.9
16+
answer = ClassificationAnswer(name=name, confidence=confidence)
1617

1718
assert answer.feature_schema_id is None
1819
assert answer.name == name
20+
assert answer.confidence == confidence
1921

2022
answer = ClassificationAnswer(feature_schema_id=feature_schema_id,
2123
name=name)
@@ -79,7 +81,7 @@ def test_subclass():
7981

8082

8183
def test_radio():
82-
answer = ClassificationAnswer(name="1")
84+
answer = ClassificationAnswer(name="1", confidence=0.81)
8385
feature_schema_id = "feature_schema_id"
8486
name = "my_feature"
8587

@@ -89,12 +91,13 @@ def test_radio():
8991

9092
with pytest.raises(ValidationError):
9193
classification = Radio(answer=[answer])
92-
classification = Radio(answer=answer)
94+
classification = Radio(answer=answer,)
9395
assert classification.dict() == {
9496
'answer': {
9597
'name': answer.name,
9698
'feature_schema_id': None,
97-
'extra': {}
99+
'extra': {},
100+
'confidence': 0.81
98101
}
99102
}
100103
classification = ClassificationAnnotation(
@@ -109,14 +112,15 @@ def test_radio():
109112
'answer': {
110113
'name': answer.name,
111114
'feature_schema_id': None,
112-
'extra': {}
115+
'extra': {},
116+
'confidence': 0.81
113117
}
114118
}
115119
}
116120

117121

118122
def test_checklist():
119-
answer = ClassificationAnswer(name="1")
123+
answer = ClassificationAnswer(name="1", confidence=0.99)
120124
feature_schema_id = "feature_schema_id"
121125
name = "my_feature"
122126

@@ -131,7 +135,8 @@ def test_checklist():
131135
'answer': [{
132136
'name': answer.name,
133137
'feature_schema_id': None,
134-
'extra': {}
138+
'extra': {},
139+
'confidence': 0.99
135140
}]
136141
}
137142
classification = ClassificationAnnotation(
@@ -147,14 +152,15 @@ def test_checklist():
147152
'answer': [{
148153
'name': answer.name,
149154
'feature_schema_id': None,
150-
'extra': {}
155+
'extra': {},
156+
'confidence': 0.99
151157
}]
152158
},
153159
}
154160

155161

156162
def test_dropdown():
157-
answer = ClassificationAnswer(name="1")
163+
answer = ClassificationAnswer(name="1", confidence=1)
158164
feature_schema_id = "feature_schema_id"
159165
name = "my_feature"
160166

@@ -169,7 +175,8 @@ def test_dropdown():
169175
'answer': [{
170176
'name': '1',
171177
'feature_schema_id': None,
172-
'extra': {}
178+
'extra': {},
179+
'confidence': 1
173180
}]
174181
}
175182
classification = ClassificationAnnotation(
@@ -184,6 +191,7 @@ def test_dropdown():
184191
'answer': [{
185192
'name': answer.name,
186193
'feature_schema_id': None,
194+
'confidence': 1,
187195
'extra': {}
188196
}]
189197
}

0 commit comments

Comments
 (0)