77from ..group import get_feature_pairs , get_identifying_key , has_no_annotations , has_no_matching_annotations
88from ...annotation_types import (ObjectAnnotation , ClassificationAnnotation ,
99 Mask , Geometry , Point , Line , Checklist , Text ,
10- Radio , ScalarMetricValue )
10+ TextEntity , Radio , ScalarMetricValue )
1111
1212
1313def miou (ground_truths : List [Union [ObjectAnnotation , ClassificationAnnotation ]],
@@ -61,6 +61,8 @@ def feature_miou(ground_truths: List[Union[ObjectAnnotation,
6161 return vector_miou (ground_truths , predictions , include_subclasses )
6262 elif isinstance (predictions [0 ], ClassificationAnnotation ):
6363 return classification_miou (ground_truths , predictions )
64+ elif isinstance (predictions [0 ].value , TextEntity ):
65+ return ner_miou (ground_truths , predictions , include_subclasses )
6466 else :
6567 raise ValueError (
6668 f"Unexpected annotation found. Found { type (predictions [0 ].value )} " )
@@ -269,3 +271,51 @@ def _ensure_valid_poly(poly):
269271def _mask_iou (mask1 : np .ndarray , mask2 : np .ndarray ) -> ScalarMetricValue :
270272 """Computes iou between two binary segmentation masks."""
271273 return np .sum (mask1 & mask2 ) / np .sum (mask1 | mask2 )
274+
275+
276+ def _get_ner_pairs (
277+ ground_truths : List [ObjectAnnotation ], predictions : List [ObjectAnnotation ]
278+ ) -> List [Tuple [ObjectAnnotation , ObjectAnnotation , ScalarMetricValue ]]:
279+ """Get iou score for all possible pairs of ground truths and predictions"""
280+ pairs = []
281+ for ground_truth , prediction in product (ground_truths , predictions ):
282+ score = _ner_iou (ground_truth .value , prediction .value )
283+ pairs .append ((ground_truth , prediction , score ))
284+ return pairs
285+
286+
287+ def _ner_iou (ner1 : TextEntity , ner2 : TextEntity ):
288+ """Computes iou between two text entity annotations"""
289+ intersection_start , intersection_end = max (ner1 .start , ner2 .start ), min (
290+ ner1 .end , ner2 .end )
291+ union_start , union_end = min (ner1 .start ,
292+ ner2 .start ), max (ner1 .end , ner2 .end )
293+ #edge case of only one character in text
294+ if union_start == union_end :
295+ return 1
296+ #if there is no intersection
297+ if intersection_start > intersection_end :
298+ return 0
299+ return (intersection_end - intersection_start ) / (union_end - union_start )
300+
301+
302+ def ner_miou (ground_truths : List [ObjectAnnotation ],
303+ predictions : List [ObjectAnnotation ],
304+ include_subclasses : bool ) -> Optional [ScalarMetricValue ]:
305+ """
306+ Computes iou score for all features with the same feature schema id.
307+ Calculation includes subclassifications.
308+
309+ Args:
310+ ground_truths: List of ground truth ner annotations
311+ predictions: List of prediction ner annotations
312+ Returns:
313+ float representing the iou score for the feature type.
314+ If there are no matches then this returns none
315+ """
316+ if has_no_matching_annotations (ground_truths , predictions ):
317+ return 0.
318+ elif has_no_annotations (ground_truths , predictions ):
319+ return None
320+ pairs = _get_ner_pairs (ground_truths , predictions )
321+ return object_pair_miou (pairs , include_subclasses )
0 commit comments