Skip to content

Commit e2253c3

Browse files
author
Val Brodsky
committed
Support confidence as an attribute of Text (freetext)
Support confidence inside Text for top-level free text. Support deserialization of confidence for free text as top node. Add free text as a classification to bbox fixture.
1 parent d801389 commit e2253c3

File tree

4 files changed

+151
-23
lines changed

4 files changed

+151
-23
lines changed

labelbox/data/serialization/ndjson/classification.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,12 +64,15 @@ class NDTextSubclass(NDAnswer):
6464
answer: str
6565

6666
def to_common(self) -> Text:
67-
return Text(answer=self.answer)
67+
return Text(answer=self.answer, confidence=self.confidence)
6868

6969
@classmethod
7070
def from_common(cls, text: Text, name: str,
7171
feature_schema_id: Cuid) -> "NDTextSubclass":
72-
return cls(answer=text.answer, name=name, schema_id=feature_schema_id)
72+
return cls(answer=text.answer,
73+
name=name,
74+
schema_id=feature_schema_id,
75+
confidence=text.confidence)
7376

7477

7578
class NDChecklistSubclass(NDAnswer):
@@ -161,7 +164,7 @@ def from_common(cls,
161164
schema_id=feature_schema_id,
162165
uuid=uuid,
163166
message_id=message_id,
164-
confidence=confidence,
167+
confidence=text.confidence,
165168
)
166169

167170

@@ -273,7 +276,6 @@ def to_common(
273276
feature_schema_id=annotation.schema_id,
274277
extra={'uuid': annotation.uuid},
275278
message_id=annotation.message_id,
276-
confidence=annotation.confidence,
277279
)
278280

279281
if getattr(annotation, 'frames', None) is None:
Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
2-
from labelbox.data.annotation_types.classification.classification import ClassificationAnswer, Radio, Text
2+
from labelbox.data.annotation_types.classification.classification import Checklist, ClassificationAnswer, Radio, Text
33
from labelbox.data.annotation_types.data.text import TextData
44
from labelbox.data.annotation_types.label import Label
55

@@ -13,40 +13,92 @@ def test_serialization():
1313
text="This is a test",
1414
),
1515
annotations=[
16-
ClassificationAnnotation(
17-
name="radio_question_geo",
18-
confidence=0.5,
19-
value=Text(answer="first_radio_answer"))
16+
ClassificationAnnotation(name="free_text_annotation",
17+
value=Text(confidence=0.5,
18+
answer="text_answer"))
2019
])
2120

2221
serialized = NDJsonConverter.serialize([label])
2322
res = next(serialized)
23+
2424
assert res['confidence'] == 0.5
25-
assert res['name'] == "radio_question_geo"
26-
assert res['answer'] == "first_radio_answer"
25+
assert res['name'] == "free_text_annotation"
26+
assert res['answer'] == "text_answer"
2727
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
2828

2929
deserialized = NDJsonConverter.deserialize([res])
3030
res = next(deserialized)
31+
3132
annotation = res.annotations[0]
32-
assert annotation.confidence == 0.5
33+
answer = annotation.value.answer
3334

3435
annotation_value = annotation.value
3536
assert type(annotation_value) is Text
36-
assert annotation_value.answer == "first_radio_answer"
37+
assert annotation_value.answer == "text_answer"
38+
assert annotation_value.confidence == 0.5
39+
40+
41+
def test_nested_serialization():
42+
label = Label(
43+
uid="ckj7z2q0b0000jx6x0q2q7q0d",
44+
data=TextData(
45+
uid="bkj7z2q0b0000jx6x0q2q7q0d",
46+
text="This is a test",
47+
),
48+
annotations=[
49+
ClassificationAnnotation(
50+
name="nested test",
51+
value=Checklist(answer=[
52+
ClassificationAnswer(
53+
name="first_answer",
54+
confidence=0.9,
55+
classifications=[
56+
ClassificationAnnotation(
57+
name="sub_radio_question",
58+
value=Radio(answer=ClassificationAnswer(
59+
name="first_sub_radio_answer",
60+
confidence=0.8,
61+
classifications=[
62+
ClassificationAnnotation(
63+
name="nested answer",
64+
value=Text(
65+
answer="nested answer",
66+
confidence=0.7,
67+
))
68+
])))
69+
])
70+
]),
71+
)
72+
])
3773

3874
serialized = NDJsonConverter.serialize([label])
3975
res = next(serialized)
40-
assert res['confidence'] == 0.5
41-
assert res['name'] == "radio_question_geo"
42-
assert res['answer'] == "first_radio_answer"
76+
4377
assert res['dataRow']['id'] == "bkj7z2q0b0000jx6x0q2q7q0d"
78+
answer = res['answer'][0]
79+
assert answer['confidence'] == 0.9
80+
assert answer['name'] == "first_answer"
81+
classification = answer['classifications'][0]
82+
nested_classification_answer = classification['answer']
83+
assert nested_classification_answer['confidence'] == 0.8
84+
assert nested_classification_answer['name'] == "first_sub_radio_answer"
85+
sub_classification = nested_classification_answer['classifications'][0]
86+
assert sub_classification['name'] == "nested answer"
87+
assert sub_classification['answer'] == "nested answer"
88+
assert sub_classification['confidence'] == 0.7
4489

4590
deserialized = NDJsonConverter.deserialize([res])
4691
res = next(deserialized)
4792
annotation = res.annotations[0]
48-
assert annotation.confidence == 0.5
93+
answer = annotation.value.answer[0]
94+
assert answer.confidence == 0.9
95+
assert answer.name == "first_answer"
4996

50-
annotation_value = annotation.value
51-
assert type(annotation_value) is Text
52-
assert annotation_value.answer == "first_radio_answer"
97+
classification_answer = answer.classifications[0].value.answer
98+
assert classification_answer.confidence == 0.8
99+
assert classification_answer.name == "first_sub_radio_answer"
100+
101+
sub_classification_answer = classification_answer.classifications[0].value
102+
assert type(sub_classification_answer) is Text
103+
assert sub_classification_answer.answer == "nested answer"
104+
assert sub_classification_answer.confidence == 0.7

tests/integration/annotation_import/conftest.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,13 +242,20 @@ def ontology():
242242
'checklist',
243243
'options': [{
244244
'label': 'nested_checkbox_option_1',
245-
'value': 'nested_checkbox_value_1'
245+
'value': 'nested_checkbox_value_1',
246+
'options': []
246247
}, {
247248
'label': 'nested_checkbox_option_2',
248249
'value': 'nested_checkbox_value_2'
249250
}]
251+
}, {
252+
'required': False,
253+
'instructions': 'nested_text',
254+
'name': 'nested_text',
255+
'type': 'text',
256+
'options': []
250257
}]
251-
}]
258+
},]
252259
}]
253260
}
254261

@@ -430,6 +437,7 @@ def configured_project(client, ontology, rand_gen, image_url):
430437
where=LabelingFrontend.name == "editor"))[0]
431438
project.setup(editor, ontology)
432439
data_row_ids = []
440+
433441
for _ in range(len(ontology['tools']) + len(ontology['classifications'])):
434442
data_row_ids.append(dataset.create_data_row(row_data=image_url).uid)
435443
project._wait_until_data_rows_are_processed(data_row_ids=data_row_ids)
@@ -559,14 +567,34 @@ def rectangle_inference(prediction_id_mapping):
559567
['featureSchemaId'],
560568
"name":
561569
rectangle['tool']['classifications'][0]['options'][0]
562-
['value']
570+
['value'],
571+
"classifications": [{
572+
"schemaId":
573+
rectangle['tool']['classifications'][0]['options'][0]
574+
['options'][1]['featureSchemaId'],
575+
"name":
576+
rectangle['tool']['classifications'][0]['options'][0]
577+
['options'][1]['name'],
578+
"answer":
579+
'nested answer'
580+
}],
563581
}
564582
}]
565583
})
566584
del rectangle['tool']
567585
return rectangle
568586

569587

588+
@pytest.fixture
589+
def rectangle_inference_with_confidence(rectangle_inference):
590+
rectangle = rectangle_inference.copy()
591+
rectangle.update({"confidence": 0.9})
592+
rectangle["classifications"][0]["answer"]["confidence"] = 0.8
593+
rectangle["classifications"][0]["answer"]["classifications"][0][
594+
"confidence"] = 0.7
595+
return rectangle
596+
597+
570598
@pytest.fixture
571599
def rectangle_inference_document(rectangle_inference):
572600
rectangle = rectangle_inference.copy()
@@ -743,6 +771,13 @@ def text_inference(prediction_id_mapping):
743771
return text
744772

745773

774+
@pytest.fixture
775+
def text_inference_with_confidence(text_inference):
776+
text = text_inference.copy()
777+
text.update({'confidence': 0.9})
778+
return text
779+
780+
746781
@pytest.fixture
747782
def text_inference_index(prediction_id_mapping):
748783
text = prediction_id_mapping['text_index'].copy()
@@ -799,6 +834,12 @@ def predictions(object_predictions, classification_predictions):
799834
return object_predictions + classification_predictions
800835

801836

837+
@pytest.fixture
838+
def predictions_with_confidence(text_inference_with_confidence,
839+
rectangle_inference_with_confidence):
840+
return [text_inference_with_confidence, rectangle_inference_with_confidence]
841+
842+
802843
@pytest.fixture
803844
def model(client, rand_gen, configured_project):
804845
ontology = configured_project.ontology()
@@ -896,6 +937,14 @@ def check_running_state(req, name, url=None):
896937
assert req.status_file_url is None
897938
assert req.state == AnnotationImportState.RUNNING
898939

940+
@staticmethod
941+
def download_and_assert_status(status_file_url):
942+
response = requests.get(status_file_url)
943+
assert response.status_code == 200
944+
for line in parser.loads(response.content):
945+
status = line['status']
946+
assert status.upper() == 'SUCCESS'
947+
899948
@staticmethod
900949
def _convert_to_plain_object(obj):
901950
"""Some Python objects e.g. tuples can't be compared with JSON serialized data, serialize to JSON and deserialize to get plain objects"""

tests/integration/annotation_import/test_mea_prediction_import.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,31 @@ def test_create_from_objects(model_run_with_data_rows, object_predictions,
3737
annotation_import.wait_until_done()
3838

3939

40+
def test_create_from_objects_with_confidence(predictions_with_confidence,
41+
model_run_with_data_rows,
42+
annotation_import_test_helpers):
43+
name = str(uuid.uuid4())
44+
45+
object_prediction_data_rows = [
46+
object_prediction["dataRow"]["id"]
47+
for object_prediction in predictions_with_confidence
48+
]
49+
# MUST have all data rows in the model run
50+
model_run_with_data_rows.upsert_data_rows(
51+
data_row_ids=object_prediction_data_rows)
52+
53+
annotation_import = model_run_with_data_rows.add_predictions(
54+
name=name, predictions=predictions_with_confidence)
55+
56+
assert annotation_import.model_run_id == model_run_with_data_rows.uid
57+
annotation_import_test_helpers.check_running_state(annotation_import, name)
58+
annotation_import_test_helpers.assert_file_content(
59+
annotation_import.input_file_url, predictions_with_confidence)
60+
annotation_import.wait_until_done()
61+
annotation_import_test_helpers.download_and_assert_status(
62+
annotation_import.status_file_url)
63+
64+
4065
def test_create_from_objects_all_project_labels(
4166
model_run_with_all_project_labels, object_predictions,
4267
annotation_import_test_helpers):

0 commit comments

Comments
 (0)