Skip to content

Commit da5d53a

Browse files
committed
merge optimize branch
2 parents 3db3559 + d1e8d8c commit da5d53a

File tree

7 files changed

+174
-32
lines changed

7 files changed

+174
-32
lines changed

specparam/core/funcs.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ def gaussian_function(xs, *params):
3232

3333
ys = np.zeros_like(xs)
3434

35-
for ii in range(0, len(params), 3):
36-
37-
ctr, hgt, wid = params[ii:ii+3]
35+
for ctr, hgt, wid in zip(*[iter(params)] * 3):
3836

3937
ys = ys + hgt * np.exp(-(xs-ctr)**2 / (2*wid**2))
4038

@@ -60,11 +58,8 @@ def expo_function(xs, *params):
6058
Output values for exponential function.
6159
"""
6260

63-
ys = np.zeros_like(xs)
64-
6561
offset, knee, exp = params
66-
67-
ys = ys + offset - np.log10(knee + xs**exp)
62+
ys = offset - np.log10(knee + xs**exp)
6863

6964
return ys
7065

@@ -88,11 +83,8 @@ def expo_nk_function(xs, *params):
8883
Output values for exponential function, without a knee.
8984
"""
9085

91-
ys = np.zeros_like(xs)
92-
9386
offset, exp = params
94-
95-
ys = ys + offset - np.log10(xs**exp)
87+
ys = offset - np.log10(xs**exp)
9688

9789
return ys
9890

@@ -113,11 +105,8 @@ def linear_function(xs, *params):
113105
Output values for linear function.
114106
"""
115107

116-
ys = np.zeros_like(xs)
117-
118108
offset, slope = params
119-
120-
ys = ys + offset + (xs*slope)
109+
ys = offset + (xs*slope)
121110

122111
return ys
123112

@@ -138,11 +127,8 @@ def quadratic_function(xs, *params):
138127
Output values for quadratic function.
139128
"""
140129

141-
ys = np.zeros_like(xs)
142-
143130
offset, slope, curve = params
144-
145-
ys = ys + offset + (xs*slope) + ((xs**2)*curve)
131+
ys = offset + (xs*slope) + ((xs**2)*curve)
146132

147133
return ys
148134

specparam/core/jacobians.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
""""Functions for computing Jacobian matrices to be used during fitting.
2+
3+
Notes
4+
-----
5+
These functions line up with those in `funcs`.
6+
The parameters in these functions are labeled {a, b, c, ...}, but follow the order in `funcs`.
7+
These functions are designed to be passed into `curve_fit` to provide a computed Jacobian.
8+
"""
9+
10+
import numpy as np
11+
12+
###################################################################################################
13+
###################################################################################################
14+
15+
## Periodic Jacobian functions
16+
17+
def jacobian_gauss(xs, *params):
18+
"""Create the Jacobian matrix for the Gaussian function.
19+
20+
Parameters
21+
----------
22+
xs : 1d array
23+
Input x-axis values.
24+
*params : float
25+
Parameters for the function.
26+
27+
Returns
28+
-------
29+
jacobian : 2d array
30+
Jacobian matrix, with shape [len(xs), n_params].
31+
"""
32+
33+
jacobian = np.zeros((len(xs), len(params)))
34+
35+
for i, (a, b, c) in enumerate(zip(*[iter(params)] * 3)):
36+
37+
ax = -a + xs
38+
ax2 = ax**2
39+
40+
c2 = c**2
41+
c3 = c**3
42+
43+
exp = np.exp(-ax2 / (2 * c2))
44+
exp_b = exp * b
45+
46+
ii = i * 3
47+
jacobian[:, ii] = (exp_b * ax) / c2
48+
jacobian[:, ii+1] = exp
49+
jacobian[:, ii+2] = (exp_b * ax2) / c3
50+
51+
return jacobian
52+
53+
54+
## Aperiodic Jacobian functions
55+
56+
def jacobian_expo(xs, *params):
57+
"""Create the Jacobian matrix for the exponential function.
58+
59+
Parameters
60+
----------
61+
xs : 1d array
62+
Input x-axis values.
63+
*params : float
64+
Parameters for the function.
65+
66+
Returns
67+
-------
68+
jacobian : 2d array
69+
Jacobian matrix, with shape [len(xs), n_params].
70+
"""
71+
72+
a, b, c = params
73+
74+
xs_c = xs**c
75+
b_xs_c = xs_c + b
76+
77+
jacobian = np.ones((len(xs), len(params)))
78+
jacobian[:, 1] = -1 / b_xs_c
79+
jacobian[:, 2] = -(xs_c * np.log10(xs)) / b_xs_c
80+
81+
return jacobian
82+
83+
84+
def jacobian_expo_nk(xs, *params):
85+
"""Create the Jacobian matrix for the exponential no-knee function.
86+
87+
Parameters
88+
----------
89+
xs : 1d array
90+
Input x-axis values.
91+
*params : float
92+
Parameters for the function.
93+
94+
Returns
95+
-------
96+
jacobian : 2d array
97+
Jacobian matrix, with shape [len(xs), n_params].
98+
"""
99+
100+
jacobian = np.ones((len(xs), len(params)))
101+
jacobian[:, 1] = -np.log10(xs)
102+
103+
return jacobian

specparam/objs/fit.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
from specparam.core.modutils import copy_doc_func_to_method
7272
from specparam.core.utils import group_three, check_array_dim
7373
from specparam.core.funcs import gaussian_function, get_ap_func, infer_ap_func
74+
from specparam.core.jacobians import jacobian_gauss
7475
from specparam.core.errors import (FitError, NoModelError, DataError,
7576
NoDataError, InconsistentDataError)
7677
from specparam.core.strings import (gen_settings_str, gen_model_results_str,
@@ -191,12 +192,17 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
191192
self._gauss_overlap_thresh = 0.75
192193
# Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev
193194
self._cf_bound = 1.5
194-
# The maximum number of calls to the curve fitting function
195-
self._maxfev = 5000
196195
# The error metric to calculate, post model fitting. See `_calc_error` for options
197196
# Note: this is for checking error post fitting, not an objective function for fitting
198197
self._error_metric = 'MAE'
199198

199+
## PRIVATE CURVE_FIT SETTINGS
200+
# The maximum number of calls to the curve fitting function
201+
self._maxfev = 5000
202+
# The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol)
203+
# Here reduce tolerance to speed fitting. Set value to 1e-8 to match curve_fit default
204+
self._tol = 0.00001
205+
200206
## RUN MODES
201207
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
202208
self._debug = False
@@ -944,7 +950,9 @@ def _simple_ap_fit(self, freqs, power_spectrum):
944950
warnings.simplefilter("ignore")
945951
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
946952
freqs, power_spectrum, p0=guess,
947-
maxfev=self._maxfev, bounds=ap_bounds)
953+
maxfev=self._maxfev, bounds=ap_bounds,
954+
ftol=self._tol, xtol=self._tol, gtol=self._tol,
955+
check_finite=False)
948956
except RuntimeError as excp:
949957
error_msg = ("Model fitting failed due to not finding parameters in "
950958
"the simple aperiodic component fit.")
@@ -1001,7 +1009,9 @@ def _robust_ap_fit(self, freqs, power_spectrum):
10011009
warnings.simplefilter("ignore")
10021010
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
10031011
freqs_ignore, spectrum_ignore, p0=popt,
1004-
maxfev=self._maxfev, bounds=ap_bounds)
1012+
maxfev=self._maxfev, bounds=ap_bounds,
1013+
ftol=self._tol, xtol=self._tol, gtol=self._tol,
1014+
check_finite=False)
10051015
except RuntimeError as excp:
10061016
error_msg = ("Model fitting failed due to not finding "
10071017
"parameters in the robust aperiodic fit.")
@@ -1147,7 +1157,9 @@ def _fit_peak_guess(self, guess):
11471157
# Fit the peaks
11481158
try:
11491159
gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
1150-
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds)
1160+
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds,
1161+
ftol=self._tol, xtol=self._tol, gtol=self._tol,
1162+
check_finite=False, jac=jacobian_gauss)
11511163
except RuntimeError as excp:
11521164
error_msg = ("Model fitting failed due to not finding "
11531165
"parameters in the peak component fit.")

specparam/plts/spectra.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, freq_r
6060
freq_range = np.log10(freq_range) if log_freqs else freq_range
6161

6262
# Make inputs iterable if need to be passed multiple times to plot each spectrum
63-
plt_powers = np.reshape(power_spectra, (1, -1)) if np.ndim(power_spectra) == 1 else \
64-
power_spectra
63+
plt_powers = np.reshape(power_spectra, (1, -1)) if isinstance(freqs, np.ndarray) and \
64+
np.ndim(power_spectra) == 1 else power_spectra
6565
plt_freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs
6666

6767
# Set labels
@@ -131,6 +131,10 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',
131131
plot_kwargs.get('log_powers', False))
132132

133133

134+
# Alias `plot_spectrum_shading` to `plot_spectra_shading` for backwards compatibility
135+
plot_spectrum_shading = plot_spectra_shading
136+
137+
134138
@savefig
135139
@style_plot
136140
@check_dependency(plt, 'matplotlib')
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""Tests for fooof.core.jacobians."""
2+
3+
from fooof.core.jacobians import *
4+
5+
###################################################################################################
6+
###################################################################################################
7+
8+
def test_jacobian_gauss():
9+
10+
xs = np.arange(1, 100)
11+
ctr, hgt, wid = 50, 5, 10
12+
13+
jacobian = jacobian_gauss(xs, ctr, hgt, wid)
14+
assert isinstance(jacobian, np.ndarray)
15+
assert jacobian.shape == (len(xs), 3)
16+
17+
def test_jacobian_expo():
18+
19+
xs = np.arange(1, 100)
20+
off, knee, exp = 10, 5, 2
21+
22+
jacobian = jacobian_expo(xs, off, knee, exp)
23+
assert isinstance(jacobian, np.ndarray)
24+
assert jacobian.shape == (len(xs), 3)
25+
26+
def test_jacobian_expo_nk():
27+
28+
xs = np.arange(1, 100)
29+
off, exp = 10, 2
30+
31+
jacobian = jacobian_expo_nk(xs, off, exp)
32+
assert isinstance(jacobian, np.ndarray)
33+
assert jacobian.shape == (len(xs), 2)

specparam/tests/objs/test_fit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ def test_fit_failure():
391391

392392
## Induce a runtime error, and check it runs through
393393
tfm = SpectralModel(verbose=False)
394-
tfm._maxfev = 5
394+
tfm._maxfev = 2
395395

396396
tfm.fit(*sim_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4]))
397397

@@ -417,7 +417,7 @@ def test_debug():
417417
"""Test model object in debug mode, including with fit failures."""
418418

419419
tfm = SpectralModel(verbose=False)
420-
tfm._maxfev = 5
420+
tfm._maxfev = 2
421421

422422
tfm.set_debug_mode(True)
423423
assert tfm._debug is True

specparam/tests/plts/test_spectra.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,22 @@
1515
@plot_test
1616
def test_plot_spectra(tfm, tfg, skip_if_no_mpl):
1717

18-
# Test with 1d inputs - 1d freq array and list of 1d power spectra
18+
# Test with 1d inputs - 1d freq array & list of 1d power spectra
1919
plot_spectra(tfm.freqs, tfm.power_spectrum,
2020
file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_1d.png')
2121

22-
# Test with 1d inputs - 1d freq array and list of 1d power spectra
22+
# Test with 1d inputs - 1d freq array & list of 1d power spectra
2323
plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]],
2424
file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_list_1d.png')
2525

2626
# Test with multiple freq inputs - list of 1d freq array and list of 1d power spectra
2727
plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]],
28-
file_path=TEST_PLOTS_PATH,
29-
file_name='test_plot_spectra_lists_1d.png')
28+
file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_list_1d_freqs.png')
29+
30+
# Test with multiple lists - list of 1d freqs & list of 1d power spectra (different f ranges)
31+
plot_spectra([tfg.freqs, tfg.freqs[:-5]],
32+
[tfg.power_spectra[0, :], tfg.power_spectra[1, :-5]],
33+
file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_lists_1d.png')
3034

3135
# Test with 2d array inputs
3236
plot_spectra(np.vstack([tfg.freqs, tfg.freqs]),

0 commit comments

Comments
 (0)