1+ from labelbox .data .annotation_types .metrics .aggregations import MetricAggregation
2+ from labelbox .data .annotation_types .metrics .scalar import CustomScalarMetric
13from labelbox .data .annotation_types .collection import LabelList
24from labelbox .data .annotation_types import ScalarMetric , Label , ImageData
35
6+ import pytest
7+
48
59def test_scalar_metric ():
610 value = 10
@@ -27,3 +31,49 @@ def test_scalar_metric():
2731 }
2832 assert label .dict () == expected
2933 next (LabelList ([label ])).dict () == expected
34+
35+
36+ @pytest .mark .parametrize ('feature_name,subclass_name,aggregation' , [
37+ ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
38+ ("cat" , None , MetricAggregation .ARITHMETIC_MEAN ),
39+ (None , None , MetricAggregation .ARITHMETIC_MEAN ),
40+ (None , None , None ),
41+ ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
42+ ("cat" , None , MetricAggregation .HARMONIC_MEAN ),
43+ (None , None , MetricAggregation .GEOMETRIC_MEAN ),
44+ (None , None , MetricAggregation .SUM ),
45+ ])
46+ def test_custom_scalar_metric (feature_name , subclass_name , aggregation ):
47+ value = 0.5
48+ kwargs = {'aggregation' : aggregation } if aggregation is not None else {}
49+ metric = CustomScalarMetric (metric_name = "iou" ,
50+ value = value ,
51+ feature_name = feature_name ,
52+ subclass_name = subclass_name ,
53+ ** kwargs )
54+ assert metric .value == value
55+
56+ label = Label (data = ImageData (uid = "ckrmd9q8g000009mg6vej7hzg" ),
57+ annotations = [metric ])
58+ expected = {
59+ 'data' : {
60+ 'external_id' : None ,
61+ 'uid' : 'ckrmd9q8g000009mg6vej7hzg' ,
62+ 'im_bytes' : None ,
63+ 'file_path' : None ,
64+ 'url' : None ,
65+ 'arr' : None
66+ },
67+ 'annotations' : [{
68+ 'value' : value ,
69+ 'metric_name' : 'iou' ,
70+ 'feature_name' : feature_name ,
71+ 'subclass_name' : subclass_name ,
72+ 'aggregation' : aggregation or MetricAggregation .ARITHMETIC_MEAN ,
73+ 'extra' : {}
74+ }],
75+ 'extra' : {},
76+ 'uid' : None
77+ }
78+ assert label .dict () == expected
79+ next (LabelList ([label ])).dict () == expected
0 commit comments