@@ -40,9 +40,7 @@ def confusion_matrix_metric(ground_truths: List[Union[
4040 if value is None :
4141 return []
4242
43- metric_name = _get_metric_name (annotation_pairs [key ][0 ],
44- annotation_pairs [key ][1 ], iou )
45-
43+ metric_name = _get_metric_name (ground_truths , predictions , iou )
4644 return [ConfusionMatrixMetric (metric_name = metric_name , value = value )]
4745
4846
@@ -90,20 +88,19 @@ def _get_metric_name(ground_truths: List[Union[ObjectAnnotation,
9088 predictions : List [Union [ObjectAnnotation ,
9189 ClassificationAnnotation ]],
9290 iou : float ):
91+
9392 if _is_classification (ground_truths , predictions ):
9493 return "classification"
95- else :
96- return f"{ int (iou * 100 )} pct_iou"
94+
95+ return f"{ int (iou * 100 )} pct_iou"
9796
9897
9998def _is_classification (ground_truths : List [Union [ObjectAnnotation ,
10099 ClassificationAnnotation ]],
101100 predictions : List [Union [ObjectAnnotation ,
102101 ClassificationAnnotation ]]):
103- if len (predictions ) and isinstance (predictions [0 ],
104- ClassificationAnnotation ):
105- return True
106- elif len (ground_truths ) and isinstance (ground_truths [0 ],
107- ClassificationAnnotation ):
108- return True
109- return False
102+ # Check if either the prediction or label contains a classification annotation
103+ return (len (predictions ) and
104+ isinstance (predictions [0 ], ClassificationAnnotation ) or
105+ len (ground_truths ) and
106+ isinstance (ground_truths [0 ], ClassificationAnnotation ))
0 commit comments