Skip to content

Commit 9507039

Browse files
committed
merge name
2 parents 97e45db + 9d854fd commit 9507039

File tree

19 files changed

+308
-73
lines changed

19 files changed

+308
-73
lines changed

doc/api.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,27 @@ Annotated plots that describe the model and fitting process.
326326
plot_annotated_model
327327
plot_annotated_peak_search
328328

329+
Plot Utilities & Styling
330+
~~~~~~~~~~~~~~~~~~~~~~~~
331+
332+
Plot related utilies for styling and managing plots.
333+
334+
.. currentmodule:: fooof.plts.style
335+
336+
.. autosummary::
337+
:toctree: generated/
338+
339+
check_style_options
340+
341+
.. currentmodule:: fooof.plts.utils
342+
343+
.. autosummary::
344+
:toctree: generated/
345+
346+
check_ax
347+
recursive_plot
348+
save_figure
349+
329350
Utilities
330351
---------
331352

specparam/core/funcs.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ def gaussian_function(xs, *params):
3232

3333
ys = np.zeros_like(xs)
3434

35-
for ii in range(0, len(params), 3):
36-
37-
ctr, hgt, wid = params[ii:ii+3]
35+
for ctr, hgt, wid in zip(*[iter(params)] * 3):
3836

3937
ys = ys + hgt * np.exp(-(xs-ctr)**2 / (2*wid**2))
4038

@@ -60,11 +58,8 @@ def expo_function(xs, *params):
6058
Output values for exponential function.
6159
"""
6260

63-
ys = np.zeros_like(xs)
64-
6561
offset, knee, exp = params
66-
67-
ys = ys + offset - np.log10(knee + xs**exp)
62+
ys = offset - np.log10(knee + xs**exp)
6863

6964
return ys
7065

@@ -88,11 +83,8 @@ def expo_nk_function(xs, *params):
8883
Output values for exponential function, without a knee.
8984
"""
9085

91-
ys = np.zeros_like(xs)
92-
9386
offset, exp = params
94-
95-
ys = ys + offset - np.log10(xs**exp)
87+
ys = offset - np.log10(xs**exp)
9688

9789
return ys
9890

@@ -113,11 +105,8 @@ def linear_function(xs, *params):
113105
Output values for linear function.
114106
"""
115107

116-
ys = np.zeros_like(xs)
117-
118108
offset, slope = params
119-
120-
ys = ys + offset + (xs*slope)
109+
ys = offset + (xs*slope)
121110

122111
return ys
123112

@@ -138,11 +127,8 @@ def quadratic_function(xs, *params):
138127
Output values for quadratic function.
139128
"""
140129

141-
ys = np.zeros_like(xs)
142-
143130
offset, slope, curve = params
144-
145-
ys = ys + offset + (xs*slope) + ((xs**2)*curve)
131+
ys = offset + (xs*slope) + ((xs**2)*curve)
146132

147133
return ys
148134

specparam/core/jacobians.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
""""Functions for computing Jacobian matrices to be used during fitting.
2+
3+
Notes
4+
-----
5+
These functions line up with those in `funcs`.
6+
The parameters in these functions are labeled {a, b, c, ...}, but follow the order in `funcs`.
7+
These functions are designed to be passed into `curve_fit` to provide a computed Jacobian.
8+
"""
9+
10+
import numpy as np
11+
12+
###################################################################################################
13+
###################################################################################################
14+
15+
## Periodic Jacobian functions
16+
17+
def jacobian_gauss(xs, *params):
18+
"""Create the Jacobian matrix for the Gaussian function.
19+
20+
Parameters
21+
----------
22+
xs : 1d array
23+
Input x-axis values.
24+
*params : float
25+
Parameters for the function.
26+
27+
Returns
28+
-------
29+
jacobian : 2d array
30+
Jacobian matrix, with shape [len(xs), n_params].
31+
"""
32+
33+
jacobian = np.zeros((len(xs), len(params)))
34+
35+
for i, (a, b, c) in enumerate(zip(*[iter(params)] * 3)):
36+
37+
ax = -a + xs
38+
ax2 = ax**2
39+
40+
c2 = c**2
41+
c3 = c**3
42+
43+
exp = np.exp(-ax2 / (2 * c2))
44+
exp_b = exp * b
45+
46+
ii = i * 3
47+
jacobian[:, ii] = (exp_b * ax) / c2
48+
jacobian[:, ii+1] = exp
49+
jacobian[:, ii+2] = (exp_b * ax2) / c3
50+
51+
return jacobian
52+
53+
54+
## Aperiodic Jacobian functions
55+
56+
def jacobian_expo(xs, *params):
57+
"""Create the Jacobian matrix for the exponential function.
58+
59+
Parameters
60+
----------
61+
xs : 1d array
62+
Input x-axis values.
63+
*params : float
64+
Parameters for the function.
65+
66+
Returns
67+
-------
68+
jacobian : 2d array
69+
Jacobian matrix, with shape [len(xs), n_params].
70+
"""
71+
72+
a, b, c = params
73+
74+
xs_c = xs**c
75+
b_xs_c = xs_c + b
76+
77+
jacobian = np.ones((len(xs), len(params)))
78+
jacobian[:, 1] = -1 / b_xs_c
79+
jacobian[:, 2] = -(xs_c * np.log10(xs)) / b_xs_c
80+
81+
return jacobian
82+
83+
84+
def jacobian_expo_nk(xs, *params):
85+
"""Create the Jacobian matrix for the exponential no-knee function.
86+
87+
Parameters
88+
----------
89+
xs : 1d array
90+
Input x-axis values.
91+
*params : float
92+
Parameters for the function.
93+
94+
Returns
95+
-------
96+
jacobian : 2d array
97+
Jacobian matrix, with shape [len(xs), n_params].
98+
"""
99+
100+
jacobian = np.ones((len(xs), len(params)))
101+
jacobian[:, 1] = -np.log10(xs)
102+
103+
return jacobian

specparam/objs/fit.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
from specparam.core.modutils import copy_doc_func_to_method
7171
from specparam.core.utils import group_three, check_array_dim
7272
from specparam.core.funcs import gaussian_function, get_ap_func, infer_ap_func
73+
from specparam.core.jacobians import jacobian_gauss
7374
from specparam.core.errors import (FitError, NoModelError, DataError,
7475
NoDataError, InconsistentDataError)
7576
from specparam.core.strings import (gen_settings_str, gen_model_results_str,
@@ -191,12 +192,17 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
191192
self._gauss_overlap_thresh = 0.75
192193
# Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev
193194
self._cf_bound = 1.5
194-
# The maximum number of calls to the curve fitting function
195-
self._maxfev = 5000
196195
# The error metric to calculate, post model fitting. See `_calc_error` for options
197196
# Note: this is for checking error post fitting, not an objective function for fitting
198197
self._error_metric = 'MAE'
199198

199+
## PRIVATE CURVE_FIT SETTINGS
200+
# The maximum number of calls to the curve fitting function
201+
self._maxfev = 5000
202+
# The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol)
203+
# Here reduce tolerance to speed fitting. Set value to 1e-8 to match curve_fit default
204+
self._tol = 0.00001
205+
200206
## RUN MODES
201207
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
202208
self._debug = False
@@ -400,7 +406,7 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None,
400406
Only relevant / effective if `freqs` and `power_spectrum` passed in in this call.
401407
**plot_kwargs
402408
Keyword arguments to pass into the plot method.
403-
Plot options with a name conflict be passed by pre-pending 'plot_'.
409+
Plot options with a name conflict be passed by pre-pending `plot_`.
404410
e.g. `freqs`, `power_spectrum` and `freq_range`.
405411
406412
Notes
@@ -921,7 +927,9 @@ def _simple_ap_fit(self, freqs, power_spectrum):
921927
warnings.simplefilter("ignore")
922928
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
923929
freqs, power_spectrum, p0=guess,
924-
maxfev=self._maxfev, bounds=ap_bounds)
930+
maxfev=self._maxfev, bounds=ap_bounds,
931+
ftol=self._tol, xtol=self._tol, gtol=self._tol,
932+
check_finite=False)
925933
except RuntimeError as excp:
926934
error_msg = ("Model fitting failed due to not finding parameters in "
927935
"the simple aperiodic component fit.")
@@ -978,7 +986,9 @@ def _robust_ap_fit(self, freqs, power_spectrum):
978986
warnings.simplefilter("ignore")
979987
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
980988
freqs_ignore, spectrum_ignore, p0=popt,
981-
maxfev=self._maxfev, bounds=ap_bounds)
989+
maxfev=self._maxfev, bounds=ap_bounds,
990+
ftol=self._tol, xtol=self._tol, gtol=self._tol,
991+
check_finite=False)
982992
except RuntimeError as excp:
983993
error_msg = ("Model fitting failed due to not finding "
984994
"parameters in the robust aperiodic fit.")
@@ -1124,7 +1134,9 @@ def _fit_peak_guess(self, guess):
11241134
# Fit the peaks
11251135
try:
11261136
gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
1127-
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds)
1137+
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds,
1138+
ftol=self._tol, xtol=self._tol, gtol=self._tol,
1139+
check_finite=False, jac=jacobian_gauss)
11281140
except RuntimeError as excp:
11291141
error_msg = ("Model fitting failed due to not finding "
11301142
"parameters in the peak component fit.")

specparam/plts/aperiodic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs)
3434
ax : matplotlib.Axes, optional
3535
Figure axes upon which to plot.
3636
**plot_kwargs
37-
Keyword arguments to pass into the ``style_plot``.
37+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
3838
"""
3939

4040
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
@@ -94,7 +94,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
9494
ax : matplotlib.Axes, optional
9595
Figure axes upon which to plot.
9696
**plot_kwargs
97-
Keyword arguments to pass into the ``style_plot``.
97+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
9898
"""
9999

100100
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))

specparam/plts/error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **pl
3333
ax : matplotlib.Axes, optional
3434
Figure axes upon which to plot.
3535
**plot_kwargs
36-
Keyword arguments to pass into the ``style_plot``.
36+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
3737
"""
3838

3939
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))

specparam/plts/group.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def plot_group(group, **plot_kwargs):
2828
group : SpectralGroupModel
2929
Object containing results from fitting a group of power spectra.
3030
**plot_kwargs
31-
Keyword arguments to apply to the plot.
31+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
3232
3333
Raises
3434
------
@@ -72,7 +72,7 @@ def plot_group_aperiodic(group, ax=None, **plot_kwargs):
7272
ax : matplotlib.Axes, optional
7373
Figure axes upon which to plot.
7474
**plot_kwargs
75-
Keyword arguments to pass into the ``style_plot``.
75+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
7676
"""
7777

7878
if group.aperiodic_mode == 'knee':
@@ -97,7 +97,7 @@ def plot_group_goodness(group, ax=None, **plot_kwargs):
9797
ax : matplotlib.Axes, optional
9898
Figure axes upon which to plot.
9999
**plot_kwargs
100-
Keyword arguments to pass into the ``style_plot``.
100+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
101101
"""
102102

103103
plot_scatter_2(group.get_params('error'), 'Error',
@@ -117,7 +117,7 @@ def plot_group_peak_frequencies(group, ax=None, **plot_kwargs):
117117
ax : matplotlib.Axes, optional
118118
Figure axes upon which to plot.
119119
**plot_kwargs
120-
Keyword arguments to pass into the ``style_plot``.
120+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
121121
"""
122122

123123
plot_hist(group.get_params('peak_params', 0)[:, 0], 'Center Frequency',

specparam/plts/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp
5656
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional
5757
Keyword arguments to pass into the plot call for each plot element.
5858
**plot_kwargs
59-
Keyword arguments to apply to the plot.
59+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
6060
6161
Notes
6262
-----
@@ -163,7 +163,7 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs):
163163
ax : matplotlib.Axes
164164
Figure axes upon which to plot.
165165
**plot_kwargs
166-
Keyword arguments to pass into the ``fill_between``.
166+
Keyword arguments to pass into ``fill_between``.
167167
"""
168168

169169
defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25}

specparam/plts/periodic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None,
3636
ax : matplotlib.Axes, optional
3737
Figure axes upon which to plot.
3838
**plot_kwargs
39-
Keyword arguments to pass into the ``style_plot``.
39+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
4040
"""
4141

4242
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
@@ -97,7 +97,7 @@ def plot_peak_fits(peaks, freq_range=None, average='mean', shade='sem', plot_ind
9797
ax : matplotlib.Axes, optional
9898
Figure axes upon which to plot.
9999
**plot_kwargs
100-
Keyword arguments to pass into the plot call.
100+
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
101101
"""
102102

103103
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))

specparam/plts/settings.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
'linestyle' : ['ls', 'linestyle']}
4747

4848
# Plot style arguments are those that can be defined on an axis object
49-
AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim']
49+
AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim',
50+
'xticks', 'yticks', 'xticklabels', 'yticklabels']
5051

5152
# Line style arguments are those that can be defined on a line object
5253
LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle',
@@ -58,8 +59,13 @@
5859
# Custom style arguments are those that are custom-handled by the plot style function
5960
CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize',
6061
'legend_size', 'legend_loc']
61-
STYLERS = ['axis_styler', 'line_styler', 'custom_styler']
62-
STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS
62+
63+
# Define list of available style functions - these can also be replaced by arguments
64+
STYLERS = ['axis_styler', 'line_styler', 'collection_styler', 'custom_styler']
65+
66+
# Collect the full set of possible style related input keyword arguments
67+
STYLE_ARGS = \
68+
AXIS_STYLE_ARGS + LINE_STYLE_ARGS + COLLECTION_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS
6369

6470
## Define default values for plot aesthetics
6571
# These are all custom style arguments

0 commit comments

Comments
 (0)