
Commit 6df07ce

Merge pull request #770 from Labelbox/sdubinin/al-4149
[AL-4149] Validate metric names against reserved names
2 parents 9682825 + 9d5883a commit 6df07ce

File tree

5 files changed: +34 −9 lines changed

labelbox/data/annotation_types/metrics/scalar.py

Lines changed: 16 additions & 1 deletion
@@ -1,7 +1,7 @@
 from typing import Dict, Optional, Union
 from enum import Enum

-from pydantic import confloat
+from pydantic import confloat, validator

 from .base import ConfidenceValue, BaseMetric

@@ -16,6 +16,11 @@ class ScalarMetricAggregation(Enum):
     SUM = "SUM"


+RESERVED_METRIC_NAMES = ('true_positive_count', 'false_positive_count',
+                         'true_negative_count', 'false_negative_count',
+                         'precision', 'recall', 'f1', 'iou')
+
+
 class ScalarMetric(BaseMetric):
     """ Class representing scalar metrics

@@ -28,6 +33,16 @@ class ScalarMetric(BaseMetric):
     value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
     aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN

+    @validator('metric_name')
+    def validate_metric_name(cls, name: Union[str, None]):
+        if name is None:
+            return None
+        clean_name = name.lower().strip()
+        if clean_name in RESERVED_METRIC_NAMES:
+            raise ValueError(f"`{clean_name}` is a reserved metric name. "
+                             "Please provide another value for `metric_name`.")
+        return name
+
     def dict(self, *args, **kwargs):
         res = super().dict(*args, **kwargs)
         if res.get('metric_name') is None:
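In effect, the new validator runs whenever a ScalarMetric is constructed: reserved names are rejected (case-insensitively, after stripping whitespace) with a pydantic ValidationError, and any other name passes through unchanged. A minimal sketch of the resulting behavior, using only names and values that appear in this commit:

from pydantic import ValidationError
from labelbox.data.annotation_types import ScalarMetric

# Reserved names such as "iou" are now rejected at construction time.
try:
    ScalarMetric(metric_name="iou", value=0.5)
except ValidationError as err:
    print(err)  # message says `iou` is a reserved metric name

# Non-reserved names still work as before.
metric = ScalarMetric(metric_name="custom_iou", value=0.5)
assert metric.metric_name == "custom_iou"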

labelbox/data/metrics/iou/iou.py

Lines changed: 4 additions & 2 deletions
@@ -31,7 +31,7 @@ def miou_metric(ground_truths: List[Union[ObjectAnnotation,
     # If both gt and preds are empty there is no metric
     if iou is None:
         return []
-    return [ScalarMetric(metric_name="iou", value=iou)]
+    return [ScalarMetric(metric_name="custom_iou", value=iou)]


 def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation,
@@ -62,7 +62,9 @@ def feature_miou_metric(ground_truths: List[Union[ObjectAnnotation,
         if value is None:
             continue
         metrics.append(
-            ScalarMetric(metric_name="iou", feature_name=key, value=value))
+            ScalarMetric(metric_name="custom_iou",
+                         feature_name=key,
+                         value=value))
     return metrics
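Because "iou" is now reserved, the built-in IOU helpers report their results under "custom_iou" instead. A rough usage sketch; the labelbox.data.metrics import path, the Rectangle/Point geometry types, and the two-argument miou_metric(ground_truths, predictions) call are assumptions based on the surrounding SDK, not shown in this diff:

from labelbox.data.annotation_types import ObjectAnnotation, Point, Rectangle
from labelbox.data.metrics import miou_metric

# Identical ground-truth and prediction boxes, so the IOU should be 1.0.
box = Rectangle(start=Point(x=0, y=0), end=Point(x=10, y=10))
ground_truths = [ObjectAnnotation(name="cat", value=box)]
predictions = [ObjectAnnotation(name="cat", value=box)]

metrics = miou_metric(ground_truths, predictions)
assert metrics[0].metric_name == "custom_iou"  # previously "iou"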

tests/data/annotation_types/test_metrics.py

Lines changed: 10 additions & 2 deletions
@@ -5,6 +5,7 @@
 from labelbox.data.annotation_types.metrics import ConfusionMatrixMetric, ScalarMetric
 from labelbox.data.annotation_types.collection import LabelList
 from labelbox.data.annotation_types import ScalarMetric, Label, ImageData
+from labelbox.data.annotation_types.metrics.scalar import RESERVED_METRIC_NAMES


 def test_legacy_scalar_metric():
@@ -56,7 +57,7 @@ def test_legacy_scalar_metric():
 ])
 def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
     kwargs = {'aggregation': aggregation} if aggregation is not None else {}
-    metric = ScalarMetric(metric_name="iou",
+    metric = ScalarMetric(metric_name="custom_iou",
                           value=value,
                           feature_name=feature_name,
                           subclass_name=subclass_name,
@@ -80,7 +81,7 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
         'value':
             value,
         'metric_name':
-            'iou',
+            'custom_iou',
         **({
             'feature_name': feature_name
         } if feature_name else {}),
@@ -192,3 +193,10 @@ def test_invalid_number_of_confidence_scores():
             metric_name="too many scores",
             value={i / 20.: [0, 1, 2, 3] for i in range(20)})
     assert "Number of confidence scores must be greater" in str(exc_info.value)
+
+
+@pytest.mark.parametrize("metric_name", RESERVED_METRIC_NAMES)
+def test_reserved_names(metric_name: str):
+    with pytest.raises(ValidationError) as exc_info:
+        ScalarMetric(metric_name=metric_name, value=0.5)
+    assert 'is a reserved metric name' in exc_info.value.errors()[0]['msg']
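The new parametrized case can be exercised on its own; for example (path taken from the file header above):

pytest tests/data/annotation_types/test_metrics.py -k test_reserved_names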
Lines changed: 3 additions & 3 deletions
@@ -1,3 +1,3 @@
-[{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "subclassName" : "sample_subclass", "aggregation" : "SUM"},
-{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "iou", "featureName" : "sample_class", "aggregation" : "SUM"},
-{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : { "0.1" : 0.1, "0.2" : 0.5}, "metricName" : "iou", "aggregation" : "SUM"}]
+[{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "custom_iou", "featureName" : "sample_class", "subclassName" : "sample_subclass", "aggregation" : "SUM"},
+{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : 0.1, "metricName" : "custom_iou", "featureName" : "sample_class", "aggregation" : "SUM"},
+{"uuid" : "a22bbf6e-b2da-4abe-9a11-df84759f7672","dataRow" : {"id": "ckrmdnqj4000007msh9p2a27r"}, "metricValue" : { "0.1" : 0.1, "0.2" : 0.5}, "metricName" : "custom_iou", "aggregation" : "SUM"}]

tests/data/metrics/iou/feature/test_feature_iou.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def check_iou(pair):
         assert math.isclose(result[key], pair.expected[key])

     for metric in metrics:
-        assert metric.metric_name == "iou"
+        assert metric.metric_name == "custom_iou"

     if len(pair.expected):
         assert len(one_metrics)
