Skip to content

Commit 849d9db

Browse files
committed
custom callables fix
1 parent d78495e commit 849d9db

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

fooof/plts/spectra.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
This file contains functions for plotting power spectra, that take in data directly.
66
"""
77

8+
from inspect import isfunction
89
from itertools import repeat, cycle
910

1011
import numpy as np
@@ -209,7 +210,7 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale
209210
Keyword arguments to be passed to `plot_spectra` or to the plot call.
210211
"""
211212

212-
if isinstance(shade, str) and power_spectra.ndim != 2:
213+
if (isinstance(shade, str) or isfunction(shade)) and power_spectra.ndim != 2:
213214
raise ValueError('Power spectra must be 2d if shade is not given.')
214215

215216
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
@@ -220,17 +221,27 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale
220221

221222
# Organize mean spectrum to plot
222223
avg_funcs = {'mean' : np.mean, 'median' : np.median}
223-
avg_func = avg_funcs[average] if isinstance(average, str) else average
224-
avg_powers = avg_func(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers
224+
225+
if isinstance(average, str) and plt_powers.ndim == 2:
226+
avg_powers = avg_funcs[average](plt_powers, axis=0)
227+
elif isfunction(average) and plt_powers.ndim == 2:
228+
avg_powers = average(plt_powers)
229+
else:
230+
avg_powers = plt_powers
225231

226232
# Plot average power spectrum
227233
ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label)
228234

229235
# Organize shading to plot
230236
shade_funcs = {'std' : np.std, 'sem' : sem}
231-
shade_func = shade_funcs[shade] if isinstance(shade, str) else shade
232-
shade_vals = scale * shade_func(plt_powers, axis=0) \
233-
if isinstance(shade, str) else scale * shade
237+
238+
if isinstance(shade, str):
239+
shade_vals = scale * shade_funcs[shade](plt_powers, axis=0)
240+
elif isfunction(shade):
241+
shade_vals = scale * shade(plt_powers)
242+
else:
243+
shade_vals = scale * shade
244+
234245
upper_shade = avg_powers + shade_vals
235246
lower_shade = avg_powers - shade_vals
236247

fooof/tests/plts/test_spectra.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,11 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg):
8787
plot_spectra_yshade(freqs, powers, shade='sem', average='median',
8888
save_fig=True, file_path=TEST_PLOTS_PATH,
8989
file_name='test_plot_spectra_yshade3.png')
90+
91+
# Plot shade with custom average and shade callables
92+
def _average_callable(powers): return np.mean(powers, axis=0)
93+
def _shade_callable(powers): return np.std(powers, axis=0)
94+
95+
plot_spectra_yshade(freqs, powers, shade=_shade_callable, average=_average_callable,
96+
log_powers=True, save_fig=True, file_path=TEST_PLOTS_PATH,
97+
file_name='test_plot_spectra_yshade4.png')

0 commit comments

Comments
 (0)