@@ -40,9 +40,8 @@ def confusion_matrix_metric(ground_truths: List[Union[
4040 if value is None :
4141 return []
4242
43- return [
44- ConfusionMatrixMetric (metric_name = f"{ int (iou * 100 )} pct_iou" , value = value )
45- ]
43+ metric_name = _get_metric_name (ground_truths , predictions , iou )
44+ return [ConfusionMatrixMetric (metric_name = metric_name , value = value )]
4645
4746
4847def feature_confusion_matrix_metric (
@@ -74,8 +73,34 @@ def feature_confusion_matrix_metric(
7473 include_subclasses , iou )
7574 if value is None :
7675 continue
76+
77+ metric_name = _get_metric_name (annotation_pairs [key ][0 ],
78+ annotation_pairs [key ][1 ], iou )
7779 metrics .append (
78- ConfusionMatrixMetric (metric_name = f" { int ( iou * 100 ) } pct_iou" ,
80+ ConfusionMatrixMetric (metric_name = metric_name ,
7981 feature_name = key ,
8082 value = value ))
8183 return metrics
84+
85+
def _get_metric_name(ground_truths: List[Union[ObjectAnnotation,
                                               ClassificationAnnotation]],
                     predictions: List[Union[ObjectAnnotation,
                                             ClassificationAnnotation]],
                     iou: float):
    """Pick the confusion-matrix metric name for an annotation pair.

    Classification annotations share one fixed name; anything else is
    tagged with the IOU threshold used for the match (e.g. "50pct_iou").
    """
    is_classification = _is_classification(ground_truths, predictions)
    return "classification" if is_classification else f"{int(iou * 100)}pct_iou"
96+
97+
def _is_classification(ground_truths: List[Union[ObjectAnnotation,
                                                 ClassificationAnnotation]],
                       predictions: List[Union[ObjectAnnotation,
                                               ClassificationAnnotation]]) -> bool:
    """Return True if either the predictions or the ground truths
    contain a classification annotation.

    Only the first element of each list is inspected; callers are
    assumed to pass homogeneous annotation lists.
    """
    # bool(...) with explicit grouping: the original un-parenthesized
    # `len(...) and isinstance(...) or ...` chain could return the int 0
    # (e.g. when both lists are empty) instead of False.
    return bool(
        (predictions and isinstance(predictions[0], ClassificationAnnotation))
        or
        (ground_truths and isinstance(ground_truths[0], ClassificationAnnotation)))
0 commit comments