Skip to content

Commit ede44e7

Browse files
author
Matt Sokoloff
committed
don't add iou in the name for classification
1 parent 48df179 commit ede44e7

File tree

1 file changed

+32
-4
lines changed

1 file changed

+32
-4
lines changed

labelbox/data/metrics/confusion_matrix/confusion_matrix.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,10 @@ def confusion_matrix_metric(ground_truths: List[Union[
4040
if value is None:
4141
return []
4242

43-
return [
44-
ConfusionMatrixMetric(metric_name=f"{int(iou*100)}pct_iou", value=value)
45-
]
43+
metric_name = _get_metric_name(annotation_pairs[key][0],
44+
annotation_pairs[key][1], iou)
45+
46+
return [ConfusionMatrixMetric(metric_name=metric_name, value=value)]
4647

4748

4849
def feature_confusion_matrix_metric(
@@ -74,8 +75,35 @@ def feature_confusion_matrix_metric(
7475
include_subclasses, iou)
7576
if value is None:
7677
continue
78+
79+
metric_name = _get_metric_name(annotation_pairs[key][0],
80+
annotation_pairs[key][1], iou)
7781
metrics.append(
78-
ConfusionMatrixMetric(metric_name=f"{int(iou*100)}pct_iou",
82+
ConfusionMatrixMetric(metric_name=name,
7983
feature_name=key,
8084
value=value))
8185
return metrics
86+
87+
88+
def _get_metric_name(ground_truths: List[Union[ObjectAnnotation,
89+
ClassificationAnnotation]],
90+
predictions: List[Union[ObjectAnnotation,
91+
ClassificationAnnotation]],
92+
iou: float):
93+
if _is_classification(ground_truths, predictions):
94+
return "classification"
95+
else:
96+
return f"{int(iou*100)}pct_iou"
97+
98+
99+
def _is_classification(ground_truths: List[Union[ObjectAnnotation,
100+
ClassificationAnnotation]],
101+
predictions: List[Union[ObjectAnnotation,
102+
ClassificationAnnotation]]):
103+
if len(predictions) and isinstance(predictions[0],
104+
ClassificationAnnotation):
105+
return True
106+
elif len(ground_truths) and isinstance(ground_truths[0],
107+
ClassificationAnnotation):
108+
return True
109+
return False

0 commit comments

Comments
 (0)