Skip to content

Commit 0f5d2fa

Browse files
authored
Merge branch 'main' into plts
2 parents 3541e51 + 083cc08 commit 0f5d2fa

26 files changed

+592
-254
lines changed

fooof/core/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ def group_three(vec):
1313
1414
Parameters
1515
----------
16-
vec : 1d array
17-
Array of items to group by 3. Length of array must be divisible by three.
16+
vec : list or 1d array
17+
List or array of items to group by 3. Length of array must be divisible by three.
1818
1919
Returns
2020
-------
21-
list of list
22-
List of lists, each with three items.
21+
array or list of list
22+
Array or list of lists, each with three items. Output type will match input type.
2323
2424
Raises
2525
------
@@ -30,7 +30,11 @@ def group_three(vec):
3030
if len(vec) % 3 != 0:
3131
raise ValueError("Wrong size array to group by three.")
3232

33-
return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
33+
# Reshape, if an array, as it's faster, otherwise asssume lise
34+
if isinstance(vec, np.ndarray):
35+
return np.reshape(vec, (-1, 3))
36+
else:
37+
return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
3438

3539

3640
def nearest_ind(array, value):

fooof/objs/fit.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
gen_issue_str, gen_width_warning_str)
7575

7676
from fooof.plts.fm import plot_fm
77-
from fooof.plts.style import style_spectrum_plot
7877
from fooof.utils.data import trim_spectrum
7978
from fooof.utils.params import compute_gauss_std
8079
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
@@ -633,12 +632,13 @@ def get_results(self):
633632
@copy_doc_func_to_method(plot_fm)
634633
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,
635634
add_legend=True, save_fig=False, file_name=None, file_path=None,
636-
ax=None, plot_style=style_spectrum_plot,
637-
data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None):
635+
ax=None, data_kwargs=None, model_kwargs=None,
636+
aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):
638637

639-
plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend,
640-
save_fig, file_name, file_path, ax, plot_style,
641-
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs)
638+
plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log,
639+
add_legend=add_legend, save_fig=save_fig, file_name=file_name,
640+
file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs,
641+
aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs)
642642

643643

644644
@copy_doc_func_to_method(save_report_fm)
@@ -1005,18 +1005,16 @@ def _create_peak_params(self, gaus_params):
10051005
with `freqs`, `fooofed_spectrum_` and `_ap_fit` all required to be available.
10061006
"""
10071007

1008-
peak_params = np.empty([0, 3])
1008+
peak_params = np.empty((len(gaus_params), 3))
10091009

10101010
for ii, peak in enumerate(gaus_params):
10111011

10121012
# Gets the index of the power_spectrum at the frequency closest to the CF of the peak
1013-
ind = min(range(len(self.freqs)), key=lambda ii: abs(self.freqs[ii] - peak[0]))
1013+
ind = np.argmin(np.abs(self.freqs - peak[0]))
10141014

10151015
# Collect peak parameter data
1016-
peak_params = np.vstack((peak_params,
1017-
[peak[0],
1018-
self.fooofed_spectrum_[ind] - self._ap_fit[ind],
1019-
peak[2] * 2]))
1016+
peak_params[ii] = [peak[0], self.fooofed_spectrum_[ind] - self._ap_fit[ind],
1017+
peak[2] * 2]
10201018

10211019
return peak_params
10221020

@@ -1035,8 +1033,8 @@ def _drop_peak_cf(self, guess):
10351033
Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].
10361034
"""
10371035

1038-
cf_params = [item[0] for item in guess]
1039-
bw_params = [item[2] * self._bw_std_edge for item in guess]
1036+
cf_params = guess[:, 0]
1037+
bw_params = guess[:, 2] * self._bw_std_edge
10401038

10411039
# Check if peaks within drop threshold from the edge of the frequency range
10421040
keep_peak = \

fooof/objs/group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ def get_params(self, name, col=None):
396396

397397

398398
@copy_doc_func_to_method(plot_fg)
399-
def plot(self, save_fig=False, file_name=None, file_path=None):
399+
def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs):
400400

401-
plot_fg(self, save_fig, file_name, file_path)
401+
plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs)
402402

403403

404404
@copy_doc_func_to_method(save_report_fg)

fooof/objs/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,14 @@ def fit_fooof_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
219219
>>> fgs = fit_fooof_3d(fg, freqs, power_spectra, freq_range=[3, 30]) # doctest:+SKIP
220220
"""
221221

222-
fgs = []
223-
for cond_spectra in power_spectra:
224-
fg.fit(freqs, cond_spectra, freq_range, n_jobs)
225-
fgs.append(fg.copy())
222+
# Reshape 3d data to 2d and fit, in order to fit with a single group model object
223+
shape = np.shape(power_spectra)
224+
powers_2d = np.reshape(power_spectra, (shape[0] * shape[1], shape[2]))
225+
226+
fg.fit(freqs, powers_2d, freq_range, n_jobs)
227+
228+
# Reorganize 2d results into a list of model group objects, to reflect original shape
229+
fgs = [fg.get_group(range(dim_a * shape[1], (dim_a + 1) * shape[1])) \
230+
for dim_a in range(shape[0])]
226231

227232
return fgs

fooof/plts/annotate.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,31 @@
66
from fooof.core.errors import NoModelError
77
from fooof.core.funcs import gaussian_function
88
from fooof.core.modutils import safe_import, check_dependency
9+
910
from fooof.sim.gen import gen_aperiodic
10-
from fooof.plts.utils import check_ax
11-
from fooof.plts.spectra import plot_spectra
12-
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
13-
from fooof.plts.style import check_n_style, style_spectrum_plot
1411
from fooof.analysis.periodic import get_band_peak_fm
1512
from fooof.utils.params import compute_knee_frequency, compute_fwhm
1613

14+
from fooof.plts.spectra import plot_spectra
15+
from fooof.plts.utils import check_ax, savefig
16+
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
17+
from fooof.plts.style import style_spectrum_plot
18+
1719
plt = safe_import('.pyplot', 'matplotlib')
1820
mpatches = safe_import('.patches', 'matplotlib')
1921

2022
###################################################################################################
2123
###################################################################################################
2224

25+
@savefig
2326
@check_dependency(plt, 'matplotlib')
24-
def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
27+
def plot_annotated_peak_search(fm):
2528
"""Plot a series of plots illustrating the peak search from a flattened spectrum.
2629
2730
Parameters
2831
----------
2932
fm : FOOOF
3033
FOOOF object, with model fit, data and settings available.
31-
plot_style : callable, optional, default: style_spectrum_plot
32-
A function to call to apply styling & aesthetics to the plots.
3334
"""
3435

3536
# Recalculate the initial aperiodic fit and flattened spectrum that
@@ -46,14 +47,12 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
4647
# This forces the creation of a new plotting axes per iteration
4748
ax = check_ax(None, PLT_FIGSIZES['spectral'])
4849

49-
plot_spectra(fm.freqs, flatspec, ax=ax, plot_style=None,
50-
label='Flattened Spectrum', color=PLT_COLORS['data'], linewidth=2.5)
51-
plot_spectra(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs),
52-
ax=ax, plot_style=None, label='Relative Threshold',
53-
color='orange', linewidth=2.5, linestyle='dashed')
54-
plot_spectra(fm.freqs, [fm.min_peak_height]*len(fm.freqs),
55-
ax=ax, plot_style=None, label='Absolute Threshold',
56-
color='red', linewidth=2.5, linestyle='dashed')
50+
plot_spectra(fm.freqs, flatspec, ax=ax, linewidth=2.5,
51+
label='Flattened Spectrum', color=PLT_COLORS['data'])
52+
plot_spectra(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), ax=ax,
53+
label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed')
54+
plot_spectra(fm.freqs, [fm.min_peak_height]*len(fm.freqs), ax=ax,
55+
label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed')
5756

5857
maxi = np.argmax(flatspec)
5958
ax.plot(fm.freqs[maxi], flatspec[maxi], '.',
@@ -65,18 +64,18 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
6564
if ind < fm.n_peaks_:
6665

6766
gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :])
68-
plot_spectra(fm.freqs, gauss, ax=ax, plot_style=None,
69-
label='Gaussian Fit', color=PLT_COLORS['periodic'],
70-
linestyle=':', linewidth=3.0)
67+
plot_spectra(fm.freqs, gauss, ax=ax, label='Gaussian Fit',
68+
color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0)
7169

7270
flatspec = flatspec - gauss
7371

74-
check_n_style(plot_style, ax, False, True)
72+
style_spectrum_plot(ax, False, True)
7573

7674

75+
@savefig
7776
@check_dependency(plt, 'matplotlib')
78-
def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True,
79-
ax=None, plot_style=style_spectrum_plot):
77+
def plot_annotated_model(fm, plt_log=False, annotate_peaks=True,
78+
annotate_aperiodic=True, ax=None):
8079
"""Plot a an annotated power spectrum and model, from a FOOOF object.
8180
8281
Parameters
@@ -91,8 +90,6 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
9190
Whether to annotate the aperiodic components of the model fit.
9291
ax : matplotlib.Axes, optional
9392
Figure axes upon which to plot.
94-
plot_style : callable, optional, default: style_spectrum_plot
95-
A function to call to apply styling & aesthetics to the plots.
9693
9794
Raises
9895
------
@@ -112,7 +109,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
112109

113110
# Create the baseline figure
114111
ax = check_ax(ax, PLT_FIGSIZES['spectral'])
115-
fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, plot_style=None,
112+
fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax,
116113
data_kwargs={'lw' : lw1, 'alpha' : 0.6},
117114
aperiodic_kwargs={'lw' : lw1, 'zorder' : 10},
118115
model_kwargs={'lw' : lw1, 'alpha' : 0.5},
@@ -219,7 +216,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
219216
color=PLT_COLORS['aperiodic'], fontsize=fontsize)
220217

221218
# Apply style to plot & tune grid styling
222-
check_n_style(plot_style, ax, plt_log, True)
219+
style_spectrum_plot(ax, plt_log, True)
223220
ax.grid(True, alpha=0.5)
224221

225222
# Add labels to plot in the legend

fooof/plts/aperiodic.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33
from itertools import cycle
44

55
import numpy as np
6+
import matplotlib.pyplot as plt
67

78
from fooof.sim.gen import gen_freqs, gen_aperiodic
89
from fooof.core.modutils import safe_import, check_dependency
910
from fooof.plts.settings import PLT_FIGSIZES
10-
from fooof.plts.style import check_n_style, style_param_plot
11-
from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs
11+
from fooof.plts.style import style_param_plot, style_plot
12+
from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs
1213

1314
plt = safe_import('.pyplot', 'matplotlib')
1415

1516
###################################################################################################
1617
###################################################################################################
1718

19+
@savefig
20+
@style_plot
1821
@check_dependency(plt, 'matplotlib')
19-
def plot_aperiodic_params(aps, colors=None, labels=None,
20-
ax=None, plot_style=style_param_plot, **plot_kwargs):
22+
def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs):
2123
"""Plot aperiodic parameters as dots representing offset and exponent value.
2224
2325
Parameters
@@ -30,38 +32,38 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
3032
Label(s) for plotted data, to be added in a legend.
3133
ax : matplotlib.Axes, optional
3234
Figure axes upon which to plot.
33-
plot_style : callable, optional, default: style_param_plot
34-
A function to call to apply styling & aesthetics to the plot.
3535
**plot_kwargs
36-
Keyword arguments to pass into the plot call.
36+
Keyword arguments to pass into the ``style_plot``.
3737
"""
3838

3939
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
4040

4141
if isinstance(aps, list):
42-
recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels,
43-
plot_style=plot_style, **plot_kwargs)
42+
recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels)
4443

4544
else:
4645

4746
# Unpack data: offset as x; exponent as y
4847
xs, ys = aps[:, 0], aps[:, -1]
4948
sizes = plot_kwargs.pop('s', 150)
5049

50+
# Create the plot
5151
plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7})
5252
ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs)
5353

5454
# Add axis labels
5555
ax.set_xlabel('Offset')
5656
ax.set_ylabel('Exponent')
5757

58-
check_n_style(plot_style, ax)
58+
style_param_plot(ax)
5959

6060

61+
@savefig
62+
@style_plot
6163
@check_dependency(plt, 'matplotlib')
6264
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
6365
log_freqs=False, colors=None, labels=None,
64-
ax=None, plot_style=style_param_plot, **plot_kwargs):
66+
ax=None, **plot_kwargs):
6567
"""Plot reconstructions of model aperiodic fits.
6668
6769
Parameters
@@ -80,10 +82,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
8082
Label(s) for plotted data, to be added in a legend.
8183
ax : matplotlib.Axes, optional
8284
Figure axes upon which to plot.
83-
plot_style : callable, optional, default: style_param_plot
84-
A function to call to apply styling & aesthetics to the plot.
8585
**plot_kwargs
86-
Keyword arguments to pass into the plot call.
86+
Keyword arguments to pass into the ``style_plot``.
8787
"""
8888

8989
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
@@ -93,11 +93,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
9393
if not colors:
9494
colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
9595

96-
recursive_plot(aps, plot_function=plot_aperiodic_fits, ax=ax,
97-
freq_range=tuple(freq_range),
98-
control_offset=control_offset,
99-
log_freqs=log_freqs, colors=colors, labels=labels,
100-
plot_style=plot_style, **plot_kwargs)
96+
recursive_plot(aps, plot_aperiodic_fits, ax=ax, freq_range=tuple(freq_range),
97+
control_offset=control_offset, log_freqs=log_freqs, colors=colors,
98+
labels=labels, **plot_kwargs)
10199
else:
102100

103101
freqs = gen_freqs(freq_range, 0.1)
@@ -118,17 +116,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
118116
# Recreate & plot the aperiodic component from parameters
119117
ap_vals = gen_aperiodic(freqs, ap_params)
120118

121-
plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25})
122-
ax.plot(plt_freqs, ap_vals, color=colors, **plot_kwargs)
119+
ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)
123120

124121
# Collect a running average across components
125122
avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0)
126123

127124
# Plot the average component
128125
avg = avg_vals / aps.shape[0]
129126
avg_color = 'black' if not colors else colors
130-
ax.plot(plt_freqs, avg, linewidth=plot_kwargs.get('linewidth')*3,
131-
color=avg_color, label=labels)
127+
ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels)
132128

133129
# Add axis labels
134130
ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency')
@@ -137,5 +133,4 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
137133
# Set plot limit
138134
ax.set_xlim(np.log10(freq_range) if log_freqs else freq_range)
139135

140-
# Apply plot style
141-
check_n_style(plot_style, ax)
136+
style_param_plot(ax)

0 commit comments

Comments
 (0)