Skip to content

Commit 1fbeb54

Browse files
authored
Merge pull request #1004 from Labelbox/VB/classification-missing-confidence_AL-5239
VB/classification missing confidence (AL-5239)
2 parents bbff2f6 + faa5745 commit 1fbeb54

File tree

10 files changed

+164
-73
lines changed

10 files changed

+164
-73
lines changed

labelbox/data/annotation_types/annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class BaseAnnotation(FeatureSchema, abc.ABC):
1616
extra: Dict[str, Any] = {}
1717

1818

19-
class ClassificationAnnotation(BaseAnnotation):
19+
class ClassificationAnnotation(BaseAnnotation, ConfidenceMixin):
2020
"""Classification annotations (non localized)
2121
2222
>>> ClassificationAnnotation(

labelbox/data/annotation_types/classification/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
# TODO: Replace when pydantic adds support for unions that don't coerce types
16-
class _TempName(BaseModel):
16+
class _TempName(ConfidenceMixin, BaseModel):
1717
name: str
1818

1919
def dict(self, *args, **kwargs):
@@ -43,7 +43,7 @@ def dict(self, *args, **kwargs) -> Dict[str, str]:
4343
return res
4444

4545

46-
class Radio(BaseModel):
46+
class Radio(ConfidenceMixin, BaseModel):
4747
""" A classification with only one selected option allowed
4848
4949
>>> Radio(answer = ClassificationAnswer(name = "dog"))
@@ -62,7 +62,7 @@ class Checklist(_TempName):
6262
answer: List[ClassificationAnswer]
6363

6464

65-
class Text(BaseModel):
65+
class Text(ConfidenceMixin, BaseModel):
6666
""" Free form text
6767
6868
>>> Text(answer = "some text answer")

labelbox/data/serialization/ndjson/classification.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -120,25 +120,33 @@ def from_common(cls, radio: Radio, name: str,
120120
class NDText(NDAnnotation, NDTextSubclass):
121121

122122
@classmethod
123-
def from_common(cls, text: Text, name: str, feature_schema_id: Cuid,
124-
extra: Dict[str, Any], data: Union[TextData,
125-
ImageData]) -> "NDText":
123+
def from_common(cls,
124+
text: Text,
125+
name: str,
126+
feature_schema_id: Cuid,
127+
extra: Dict[str, Any],
128+
data: Union[TextData, ImageData],
129+
confidence: Optional[float] = None) -> "NDText":
126130
return cls(
127131
answer=text.answer,
128132
data_row=DataRow(id=data.uid, global_key=data.global_key),
129133
name=name,
130134
schema_id=feature_schema_id,
131135
uuid=extra.get('uuid'),
136+
confidence=confidence,
132137
)
133138

134139

135140
class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
136141

137142
@classmethod
138-
def from_common(
139-
cls, checklist: Checklist, name: str, feature_schema_id: Cuid,
140-
extra: Dict[str, Any], data: Union[VideoData, TextData,
141-
ImageData]) -> "NDChecklist":
143+
def from_common(cls,
144+
checklist: Checklist,
145+
name: str,
146+
feature_schema_id: Cuid,
147+
extra: Dict[str, Any],
148+
data: Union[VideoData, TextData, ImageData],
149+
confidence: Optional[float] = None) -> "NDChecklist":
142150
return cls(answer=[
143151
NDFeature(name=answer.name,
144152
schema_id=answer.feature_schema_id,
@@ -149,23 +157,29 @@ def from_common(
149157
name=name,
150158
schema_id=feature_schema_id,
151159
uuid=extra.get('uuid'),
152-
frames=extra.get('frames'))
160+
frames=extra.get('frames'),
161+
confidence=confidence)
153162

154163

155164
class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):
156165

157166
@classmethod
158-
def from_common(cls, radio: Radio, name: str, feature_schema_id: Cuid,
159-
extra: Dict[str, Any], data: Union[VideoData, TextData,
160-
ImageData]) -> "NDRadio":
167+
def from_common(cls,
168+
radio: Radio,
169+
name: str,
170+
feature_schema_id: Cuid,
171+
extra: Dict[str, Any],
172+
data: Union[VideoData, TextData, ImageData],
173+
confidence: Optional[float] = None) -> "NDRadio":
161174
return cls(answer=NDFeature(name=radio.answer.name,
162175
schema_id=radio.answer.feature_schema_id,
163176
confidence=radio.answer.confidence),
164177
data_row=DataRow(id=data.uid, global_key=data.global_key),
165178
name=name,
166179
schema_id=feature_schema_id,
167180
uuid=extra.get('uuid'),
168-
frames=extra.get('frames'))
181+
frames=extra.get('frames'),
182+
confidence=confidence)
169183

170184

171185
class NDSubclassification:
@@ -212,7 +226,8 @@ def to_common(
212226
value=annotation.to_common(),
213227
name=annotation.name,
214228
feature_schema_id=annotation.schema_id,
215-
extra={'uuid': annotation.uuid})
229+
extra={'uuid': annotation.uuid},
230+
confidence=annotation.confidence)
216231
if getattr(annotation, 'frames', None) is None:
217232
return [common]
218233
results = []
@@ -235,7 +250,8 @@ def from_common(
235250
)
236251
return classify_obj.from_common(annotation.value, annotation.name,
237252
annotation.feature_schema_id,
238-
annotation.extra, data)
253+
annotation.extra, data,
254+
annotation.confidence)
239255

240256
@staticmethod
241257
def lookup_classification(
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from labelbox.data.annotation_types.classification.classification import Text
2+
3+
4+
def test_text():
5+
text_entity = Text(answer="good job")
6+
assert text_entity.answer == "good job"
7+
8+
9+
def test_text_confidence():
10+
text_entity = Text(answer="good job", confidence=0.5)
11+
assert text_entity.answer == "good job"
12+
assert text_entity.confidence == 0.5

tests/data/assets/ndjson/text_import.json

Lines changed: 0 additions & 25 deletions
This file was deleted.

tests/data/assets/ndjson/text_import_name_only.json

Lines changed: 0 additions & 15 deletions
This file was deleted.
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import json
2+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
3+
from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio
4+
from labelbox.data.annotation_types.data.text import TextData
5+
from labelbox.data.annotation_types.label import Label
6+
7+
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
8+
9+
10+
def test_serialization():
11+
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
12+
data=TextData(
13+
uid="bkj7z2q0b0000jx6x0q2q7q0d",
14+
text="This is a test",
15+
),
16+
annotations=[
17+
ClassificationAnnotation(
18+
name="checkbox_question_geo",
19+
confidence=0.5,
20+
value=Checklist(answer=[
21+
ClassificationAnswer(name="first_answer"),
22+
ClassificationAnswer(name="second_answer")
23+
]))
24+
])
25+
26+
serialized = NDJsonConverter.serialize([label])
27+
28+
res = next(serialized)
29+
assert res['confidence'] == 0.5
30+
assert res['name'] == "checkbox_question_geo"
31+
assert res['answer'][0]['name'] == "first_answer"
32+
assert res['answer'][1]['name'] == "second_answer"
33+
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
34+
35+
deserialized = NDJsonConverter.deserialize([res])
36+
res = next(deserialized)
37+
annotation = res.annotations[0]
38+
assert annotation.confidence == 0.5
39+
40+
annotation_value = annotation.value
41+
assert type(annotation_value) is Checklist
42+
assert annotation_value.answer[0].name == "first_answer"
43+
assert annotation_value.answer[1].name == "second_answer"
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import json
2+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
3+
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio
4+
from labelbox.data.annotation_types.data.text import TextData
5+
from labelbox.data.annotation_types.label import Label
6+
7+
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
8+
9+
10+
def test_serialization():
11+
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
12+
data=TextData(
13+
uid="bkj7z2q0b0000jx6x0q2q7q0d",
14+
text="This is a test",
15+
),
16+
annotations=[
17+
ClassificationAnnotation(
18+
name="radio_question_geo",
19+
confidence=0.5,
20+
value=Radio(answer=ClassificationAnswer(
21+
confidence=0.6, name="first_radio_answer")))
22+
])
23+
24+
serialized = NDJsonConverter.serialize([label])
25+
res = next(serialized)
26+
assert res['confidence'] == 0.5
27+
assert res['name'] == "radio_question_geo"
28+
assert res['answer']['name'] == "first_radio_answer"
29+
assert res['answer']['confidence'] == 0.6
30+
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
31+
32+
deserialized = NDJsonConverter.deserialize([res])
33+
res = next(deserialized)
34+
annotation = res.annotations[0]
35+
assert annotation.confidence == 0.5
36+
37+
annotation_value = annotation.value
38+
assert type(annotation_value) is Radio
39+
assert annotation_value.answer.name == "first_radio_answer"
Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,37 @@
11
import json
2+
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
3+
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio, Text
4+
from labelbox.data.annotation_types.data.text import TextData
5+
from labelbox.data.annotation_types.label import Label
26

37
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
48

59

6-
def test_text():
7-
with open('tests/data/assets/ndjson/text_import.json', 'r') as file:
8-
data = json.load(file)
9-
res = list(NDJsonConverter.deserialize(data))
10-
res = list(NDJsonConverter.serialize(res))
11-
assert res == data
10+
def test_serialization():
11+
label = Label(uid="ckj7z2q0b0000jx6x0q2q7q0d",
12+
data=TextData(
13+
uid="bkj7z2q0b0000jx6x0q2q7q0d",
14+
text="This is a test",
15+
),
16+
annotations=[
17+
ClassificationAnnotation(
18+
name="radio_question_geo",
19+
confidence=0.5,
20+
value=Text(answer="first_radio_answer"))
21+
])
1222

23+
serialized = NDJsonConverter.serialize([label])
24+
res = next(serialized)
25+
assert res['confidence'] == 0.5
26+
assert res['name'] == "radio_question_geo"
27+
assert res['answer'] == "first_radio_answer"
28+
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
1329

14-
def test_text_name_only():
15-
with open('tests/data/assets/ndjson/text_import_name_only.json',
16-
'r') as file:
17-
data = json.load(file)
18-
res = list(NDJsonConverter.deserialize(data))
19-
res = list(NDJsonConverter.serialize(res))
20-
assert res == data
30+
deserialized = NDJsonConverter.deserialize([res])
31+
res = next(deserialized)
32+
annotation = res.annotations[0]
33+
assert annotation.confidence == 0.5
34+
35+
annotation_value = annotation.value
36+
assert type(annotation_value) is Text
37+
assert annotation_value.answer == "first_radio_answer"

tests/integration/annotation_import/conftest.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def model_run_with_model_run_data_rows(client, configured_project,
658658
labels = wait_for_label_processing(configured_project)
659659
label_ids = [label.uid for label in labels]
660660
model_run.upsert_labels(label_ids)
661-
time.sleep(3)
661+
time.sleep(300)
662662
yield model_run
663663
model_run.delete()
664664
# TODO: Delete resources when that is possible ..
@@ -670,6 +670,11 @@ def model_run_with_all_project_labels(client, configured_project,
670670
wait_for_label_processing):
671671
configured_project.enable_model_assisted_labeling()
672672

673+
data_row_ids = configured_project.data_row_ids
674+
675+
configured_project._wait_until_data_rows_are_processed(
676+
data_row_ids=data_row_ids)
677+
673678
upload_task = LabelImport.create_from_objects(
674679
client, configured_project.uid, f"label-import-{uuid.uuid4()}",
675680
model_run_predictions)
@@ -680,7 +685,6 @@ def model_run_with_all_project_labels(client, configured_project,
680685
) == 0, f"Label Import {upload_task.name} failed with errors {upload_task.errors}"
681686
wait_for_label_processing(configured_project)
682687
model_run.upsert_labels(project_id=configured_project.uid)
683-
time.sleep(3)
684688
yield model_run
685689
model_run.delete()
686690
# TODO: Delete resources when that is possible ..

0 commit comments

Comments
 (0)