Skip to content

Commit d78495e

Browse files
committed
update args for average & shade
1 parent 2ec52b1 commit d78495e

File tree

2 files changed

+42
-20
lines changed

2 files changed

+42
-20
lines changed

fooof/plts/spectra.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from itertools import repeat, cycle
99

1010
import numpy as np
11+
from scipy.stats import sem
1112

1213
from fooof.core.modutils import safe_import, check_dependency
1314
from fooof.plts.settings import PLT_FIGSIZES
@@ -177,8 +178,9 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',
177178
@savefig
178179
@style_plot
179180
@check_dependency(plt, 'matplotlib')
180-
def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=False,
181-
log_powers=False, ax=None, **plot_kwargs):
181+
def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale=1,
182+
log_freqs=False, log_powers=False, color=None, label=None,
183+
ax=None, **plot_kwargs):
182184
"""Plot standard deviation or error as a shaded region around the mean spectrum.
183185
184186
Parameters
@@ -187,23 +189,27 @@ def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=Fal
187189
Frequency values, to be plotted on the x-axis.
188190
power_spectra : 1d or 2d array
189191
Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d.
190-
shade : 1d array, optional, default: None
191-
Powers to shade above/below the mean spectrum. None defaults to one standard deviation.
192+
shade : 'std', 'sem', 1d array or callable, optional, default: 'std'
193+
Approach for shading above/below the mean spectrum.
194+
average : 'mean', 'median' or callable, optional, default: 'mean'
195+
Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d.
192196
scale : int, optional, default: 1
193-
Factor to multiply the the standard deviation, or ``shade``, by.
197+
Factor to multiply the plotted shade by.
194198
log_freqs : bool, optional, default: False
195199
Whether to plot the frequency axis in log spacing.
196200
log_powers : bool, optional, default: False
197201
Whether to plot the power axis in log spacing.
202+
color : str, optional, default: None
203+
Line color of the spectrum.
204+
label : str, optional, default: None
205+
Legend label for the spectrum.
198206
ax : matplotlib.Axes, optional
199207
Figure axes upon which to plot.
200-
plot_style : callable, optional, default: style_spectrum_plot
201-
A function to call to apply styling & aesthetics to the plot.
202208
**plot_kwargs
203209
Keyword arguments to be passed to `plot_spectra` or to the plot call.
204210
"""
205211

206-
if shade is None and power_spectra.ndim != 2:
212+
if isinstance(shade, str) and power_spectra.ndim != 2:
207213
raise ValueError('Power spectra must be 2d if shade is not given.')
208214

209215
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
@@ -212,16 +218,25 @@ def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=Fal
212218
plt_freqs = np.log10(freqs) if log_freqs else freqs
213219
plt_powers = np.log10(power_spectra) if log_powers else power_spectra
214220

215-
# Plot mean
216-
powers_mean = np.mean(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers
217-
ax.plot(plt_freqs, powers_mean)
221+
# Organize mean spectrum to plot
222+
avg_funcs = {'mean' : np.mean, 'median' : np.median}
223+
avg_func = avg_funcs[average] if isinstance(average, str) else average
224+
avg_powers = avg_func(plt_powers, axis=0) if plt_powers.ndim == 2 else plt_powers
225+
226+
# Plot average power spectrum
227+
ax.plot(plt_freqs, avg_powers, linewidth=2.0, color=color, label=label)
218228

219-
# Shade +/- scale * (standard deviation or shade)
220-
shade = scale * np.std(plt_powers, axis=0) if shade is None else scale * shade
221-
upper_shade = powers_mean + shade
222-
lower_shade = powers_mean - shade
229+
# Organize shading to plot
230+
shade_funcs = {'std' : np.std, 'sem' : sem}
231+
shade_func = shade_funcs[shade] if isinstance(shade, str) else shade
232+
shade_vals = scale * shade_func(plt_powers, axis=0) \
233+
if isinstance(shade, str) else scale * shade
234+
upper_shade = avg_powers + shade_vals
235+
lower_shade = avg_powers - shade_vals
223236

237+
# Plot +/- yshading around spectrum
224238
alpha = plot_kwargs.pop('alpha', 0.25)
225-
ax.fill_between(plt_freqs, lower_shade, upper_shade, alpha=alpha, **plot_kwargs)
239+
ax.fill_between(plt_freqs, lower_shade, upper_shade,
240+
alpha=alpha, color=color, **plot_kwargs)
226241

227242
style_spectrum_plot(ax, log_freqs, log_powers)

fooof/tests/plts/test_spectra.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,18 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg):
7272
with raises(ValueError):
7373
plot_spectra_yshade(freqs, powers[0])
7474

75-
# Valid 1d array with shade
76-
plot_spectra_yshade(freqs, np.mean(powers, axis=0), shade=np.std(powers, axis=0),
75+
# Plot with 2d array
76+
plot_spectra_yshade(freqs, powers, shade='std',
7777
save_fig=True, file_path=TEST_PLOTS_PATH,
7878
file_name='test_plot_spectra_yshade1.png')
7979

80-
# 2d array
81-
plot_spectra_yshade(freqs, powers, save_fig=True, file_path=TEST_PLOTS_PATH,
80+
# Plot shade with given 1d array
81+
plot_spectra_yshade(freqs, np.mean(powers, axis=0),
82+
shade=np.std(powers, axis=0),
83+
save_fig=True, file_path=TEST_PLOTS_PATH,
8284
file_name='test_plot_spectra_yshade2.png')
85+
86+
# Plot shade with different average and shade approaches
87+
plot_spectra_yshade(freqs, powers, shade='sem', average='median',
88+
save_fig=True, file_path=TEST_PLOTS_PATH,
89+
file_name='test_plot_spectra_yshade3.png')

0 commit comments

Comments
 (0)