55
66import pytest
77
8+
89def test_scalar_metric ():
910 value = 10
1011 metric = ScalarMetric (value = value )
@@ -32,17 +33,24 @@ def test_scalar_metric():
3233 next (LabelList ([label ])).dict () == expected
3334
3435
35- @pytest .mark .parametrize (
36- 'feature_name,subclass_name,aggregation' ,
37- [
38- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
39- ("cat" , None , MetricAggregation .ARITHMETIC_MEAN ),
40- (None , None , MetricAggregation .ARITHMETIC_MEAN ),
41- (None , None , None ),
42- ])
36+ @pytest .mark .parametrize ('feature_name,subclass_name,aggregation' , [
37+ ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
38+ ("cat" , None , MetricAggregation .ARITHMETIC_MEAN ),
39+ (None , None , MetricAggregation .ARITHMETIC_MEAN ),
40+ (None , None , None ),
41+ ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
42+ ("cat" , None , MetricAggregation .HARMONIC_MEAN ),
43+ (None , None , MetricAggregation .GEOMETRIC_MEAN ),
44+ (None , None , MetricAggregation .SUM ),
45+ ])
4346def test_custom_scalar_metric (feature_name , subclass_name , aggregation ):
4447 value = 0.5
45- metric = CustomScalarMetric (metric_name = "iou" , value = value , feature_name = feature_name , subclass_name = subclass_name , aggregation = aggregation )
48+ kwargs = {'aggregation' : aggregation } if aggregation is not None else {}
49+ metric = CustomScalarMetric (metric_name = "iou" ,
50+ value = value ,
51+ feature_name = feature_name ,
52+ subclass_name = subclass_name ,
53+ ** kwargs )
4654 assert metric .value == value
4755
4856 label = Label (data = ImageData (uid = "ckrmd9q8g000009mg6vej7hzg" ),
@@ -57,15 +65,15 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
5765 'arr' : None
5866 },
5967 'annotations' : [{
60- 'value' : 10.0 ,
61-
68+ 'value' : value ,
69+ 'metric_name' : 'iou' ,
70+ 'feature_name' : feature_name ,
71+ 'subclass_name' : subclass_name ,
72+ 'aggregation' : aggregation or MetricAggregation .ARITHMETIC_MEAN ,
6273 'extra' : {}
6374 }],
6475 'extra' : {},
6576 'uid' : None
6677 }
6778 assert label .dict () == expected
6879 next (LabelList ([label ])).dict () == expected
69-
70-
71-
0 commit comments