Skip to content

Commit 8f09793

Browse files
committed
AL-4149: Validate metric names against reserved names
1 parent 0f38024 commit 8f09793

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

labelbox/data/annotation_types/metrics/scalar.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Dict, Optional, Union
22
from enum import Enum
33

4-
from pydantic import confloat
4+
from pydantic import confloat, validator
55

66
from .base import ConfidenceValue, BaseMetric
77

@@ -16,6 +16,18 @@ class ScalarMetricAggregation(Enum):
1616
SUM = "SUM"
1717

1818

19+
RESERVED_METRIC_NAMES = (
20+
'true_positive_count',
21+
'false_positive_count',
22+
'true_negative_count',
23+
'false_negative_count',
24+
'precision',
25+
'recall',
26+
'f1',
27+
'iou'
28+
)
29+
30+
1931
class ScalarMetric(BaseMetric):
2032
""" Class representing scalar metrics
2133
@@ -28,6 +40,14 @@ class ScalarMetric(BaseMetric):
2840
value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
2941
aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN
3042

43+
@validator('metric_name')
44+
def validate_metric_name(cls, name: str):
45+
clean_name = name.lower().strip()
46+
if name.lower().strip() in RESERVED_METRIC_NAMES:
47+
raise ValueError(f"`{clean_name}` is a reserved metric name. "
48+
"Please provide another value for `metric_name`.")
49+
return clean_name
50+
3151
def dict(self, *args, **kwargs):
3252
res = super().dict(*args, **kwargs)
3353
if res.get('metric_name') is None:

labelbox/data/metrics/iou/iou.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def miou_metric(ground_truths: List[Union[ObjectAnnotation,
3131
# If both gt and preds are empty there is no metric
3232
if iou is None:
3333
return []
34-
return [ScalarMetric(metric_name="iou", value=iou)]
34+
return [ScalarMetric(metric_name="custom_iou", value=iou)]
3535

3636

3737
def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation,
@@ -62,7 +62,7 @@ def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation,
6262
if value is None:
6363
continue
6464
metrics.append(
65-
ScalarMetric(metric_name="iou", feature_name=key, value=value))
65+
ScalarMetric(metric_name="custom_iou", feature_name=key, value=value))
6666
return metrics
6767

6868

tests/data/annotation_types/test_metrics.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric, ScalarMetric
66
from labelbox.data.annotation_types.collection import LabelList
77
from labelbox.data.annotation_types import ScalarMetric, Label, ImageData
8+
from labelbox.data.annotation_types.metrics.scalar import RESERVED_METRIC_NAMES
89

910

1011
def test_legacy_scalar_metric():
@@ -56,7 +57,7 @@ def test_legacy_scalar_metric():
5657
])
5758
def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
5859
kwargs = {'aggregation': aggregation} if aggregation is not None else {}
59-
metric = ScalarMetric(metric_name="iou",
60+
metric = ScalarMetric(metric_name="custom_iou",
6061
value=value,
6162
feature_name=feature_name,
6263
subclass_name=subclass_name,
@@ -80,7 +81,7 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
8081
'value':
8182
value,
8283
'metric_name':
83-
'iou',
84+
'custom_iou',
8485
**({
8586
'feature_name': feature_name
8687
} if feature_name else {}),
@@ -192,3 +193,10 @@ def test_invalid_number_of_confidence_scores():
192193
metric_name="too many scores",
193194
value={i / 20.: [0, 1, 2, 3] for i in range(20)})
194195
assert "Number of confidence scores must be greater" in str(exc_info.value)
196+
197+
198+
@pytest.mark.parametrize("metric_name", RESERVED_METRIC_NAMES)
199+
def test_reserved_names(metric_name: str):
200+
with pytest.raises(ValidationError) as exc_info:
201+
ScalarMetric(metric_name=metric_name, value=0.5)
202+
assert 'is a reserved metric name' in exc_info.value.errors()[0]['msg']

0 commit comments

Comments
 (0)