@@ -40,9 +40,10 @@ def confusion_matrix_metric(ground_truths: List[Union[
4040 if value is None :
4141 return []
4242
43- return [
44- ConfusionMatrixMetric (metric_name = f"{ int (iou * 100 )} pct_iou" , value = value )
45- ]
43+ metric_name = _get_metric_name (annotation_pairs [key ][0 ],
44+ annotation_pairs [key ][1 ], iou )
45+
46+ return [ConfusionMatrixMetric (metric_name = metric_name , value = value )]
4647
4748
4849def feature_confusion_matrix_metric (
@@ -74,8 +75,35 @@ def feature_confusion_matrix_metric(
7475 include_subclasses , iou )
7576 if value is None :
7677 continue
78+
79+ metric_name = _get_metric_name (annotation_pairs [key ][0 ],
80+ annotation_pairs [key ][1 ], iou )
7781 metrics .append (
78- ConfusionMatrixMetric (metric_name = f" { int ( iou * 100 ) } pct_iou" ,
82+ ConfusionMatrixMetric (metric_name = name ,
7983 feature_name = key ,
8084 value = value ))
8185 return metrics
86+
87+
88+ def _get_metric_name (ground_truths : List [Union [ObjectAnnotation ,
89+ ClassificationAnnotation ]],
90+ predictions : List [Union [ObjectAnnotation ,
91+ ClassificationAnnotation ]],
92+ iou : float ):
93+ if _is_classification (ground_truths , predictions ):
94+ return "classification"
95+ else :
96+ return f"{ int (iou * 100 )} pct_iou"
97+
98+
99+ def _is_classification (ground_truths : List [Union [ObjectAnnotation ,
100+ ClassificationAnnotation ]],
101+ predictions : List [Union [ObjectAnnotation ,
102+ ClassificationAnnotation ]]):
103+ if len (predictions ) and isinstance (predictions [0 ],
104+ ClassificationAnnotation ):
105+ return True
106+ elif len (ground_truths ) and isinstance (ground_truths [0 ],
107+ ClassificationAnnotation ):
108+ return True
109+ return False
0 commit comments