Skip to content

Commit ca1260b

Browse files
committed
refactor metric definitions
1 parent a8bd2fb commit ca1260b

File tree

1 file changed

+51
-8
lines changed

1 file changed

+51
-8
lines changed

specparam/measures/metrics.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,62 @@
66
from specparam.measures.gof import compute_r_squared, compute_adj_r_squared
77

88
###################################################################################################
9+
## ERROR METRICS
10+
11+
error_mae = Metric(
12+
category='error',
13+
measure='mae',
14+
func=compute_mean_abs_error,
15+
)
16+
17+
error_mse = Metric(
18+
category='error',
19+
measure='mse',
20+
func=compute_mean_squared_error
21+
)
22+
23+
error_rmse = Metric(
24+
category='error',
25+
measure='rmse',
26+
func=compute_root_mean_squared_error,
27+
)
28+
29+
error_medae = Metric(
30+
category='error',
31+
measure='medae',
32+
func=compute_median_abs_error,
33+
)
34+
935
###################################################################################################
36+
## GOF
37+
38+
gof_rsquared = Metric(
39+
category='gof',
40+
measure='rsquared',
41+
func=compute_r_squared,
42+
)
43+
44+
gof_adjrsquared = Metric(
45+
category='gof',
46+
measure='adjrsquared',
47+
func=compute_adj_r_squared,
48+
kwargs={'n_params' : lambda data, results: \
49+
results.params.periodic.params.size + results.params.aperiodic.params.size},
50+
)
51+
52+
###################################################################################################
53+
## COLLECT ALL METRICS TOGETHER
1054

1155
METRICS = {
1256

1357
# Available error metrics
14-
'error_mae' : Metric('error', 'mae', compute_mean_abs_error),
15-
'error_mse' : Metric('error', 'mse', compute_mean_squared_error),
16-
'error_rmse' : Metric('error', 'rmse', compute_root_mean_squared_error),
17-
'error_medae' : Metric('error', 'medae', compute_median_abs_error),
58+
'error_mae' : error_mae,
59+
'error_mse' : error_mse,
60+
'error_rmse' : error_rmse,
61+
'error_medae' : error_medae,
1862

1963
# Available GOF / r-squared metrics
20-
'gof_rsquared' : Metric('gof', 'rsquared', compute_r_squared),
21-
'gof_adjrsquared' : Metric('gof', 'adjrsquared', compute_adj_r_squared, \
22-
{'n_params' : lambda data, results: \
23-
results.params.periodic.params.size + results.params.aperiodic.params.size})
64+
'gof_rsquared' : gof_rsquared,
65+
'gof_adjrsquared' : gof_adjrsquared,
66+
2467
}

0 commit comments

Comments
 (0)