88from itertools import repeat , cycle
99
1010import numpy as np
11+ from scipy .stats import sem
1112
1213from fooof .core .modutils import safe_import , check_dependency
1314from fooof .plts .settings import PLT_FIGSIZES
@@ -177,8 +178,9 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',
177178@savefig
178179@style_plot
179180@check_dependency (plt , 'matplotlib' )
180- def plot_spectra_yshade (freqs , power_spectra , shade = None , scale = 1 , log_freqs = False ,
181- log_powers = False , ax = None , ** plot_kwargs ):
181+ def plot_spectra_yshade (freqs , power_spectra , shade = 'std' , average = 'mean' , scale = 1 ,
182+ log_freqs = False , log_powers = False , color = None , label = None ,
183+ ax = None , ** plot_kwargs ):
182184 """Plot standard deviation or error as a shaded region around the mean spectrum.
183185
184186 Parameters
@@ -187,23 +189,27 @@ def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=Fal
187189 Frequency values, to be plotted on the x-axis.
188190 power_spectra : 1d or 2d array
189191 Power values, to be plotted on the y-axis. ``shade`` must be provided if 1d.
190- shade : 1d array, optional, default: None
191- Powers to shade above/below the mean spectrum. None defaults to one standard deviation.
192+ shade : 'std', 'sem', 1d array or callable, optional, default: 'std'
193+ Approach for shading above/below the mean spectrum.
194+ average : 'mean', 'median' or callable, optional, default: 'mean'
195+ Averaging approach for the average spectrum to plot. Only used if power_spectra is 2d.
192196 scale : int, optional, default: 1
193- Factor to multiply the the standard deviation, or `` shade``, by.
197+ Factor to multiply the plotted shade by.
194198 log_freqs : bool, optional, default: False
195199 Whether to plot the frequency axis in log spacing.
196200 log_powers : bool, optional, default: False
197201 Whether to plot the power axis in log spacing.
202+ color : str, optional, default: None
203+ Line color of the spectrum.
204+ label : str, optional, default: None
205+ Legend label for the spectrum.
198206 ax : matplotlib.Axes, optional
199207 Figure axes upon which to plot.
200- plot_style : callable, optional, default: style_spectrum_plot
201- A function to call to apply styling & aesthetics to the plot.
202208 **plot_kwargs
203209 Keyword arguments to be passed to `plot_spectra` or to the plot call.
204210 """
205211
206- if shade is None and power_spectra .ndim != 2 :
212+ if isinstance ( shade , str ) and power_spectra .ndim != 2 :
207213 raise ValueError ('Power spectra must be 2d if shade is not given.' )
208214
209215 ax = check_ax (ax , plot_kwargs .pop ('figsize' , PLT_FIGSIZES ['spectral' ]))
@@ -212,16 +218,25 @@ def plot_spectra_yshade(freqs, power_spectra, shade=None, scale=1, log_freqs=Fal
212218 plt_freqs = np .log10 (freqs ) if log_freqs else freqs
213219 plt_powers = np .log10 (power_spectra ) if log_powers else power_spectra
214220
215- # Plot mean
216- powers_mean = np .mean (plt_powers , axis = 0 ) if plt_powers .ndim == 2 else plt_powers
217- ax .plot (plt_freqs , powers_mean )
221+ # Organize mean spectrum to plot
222+ avg_funcs = {'mean' : np .mean , 'median' : np .median }
223+ avg_func = avg_funcs [average ] if isinstance (average , str ) else average
224+ avg_powers = avg_func (plt_powers , axis = 0 ) if plt_powers .ndim == 2 else plt_powers
225+
226+ # Plot average power spectrum
227+ ax .plot (plt_freqs , avg_powers , linewidth = 2.0 , color = color , label = label )
218228
219- # Shade +/- scale * (standard deviation or shade)
220- shade = scale * np .std (plt_powers , axis = 0 ) if shade is None else scale * shade
221- upper_shade = powers_mean + shade
222- lower_shade = powers_mean - shade
229+ # Organize shading to plot
230+ shade_funcs = {'std' : np .std , 'sem' : sem }
231+ shade_func = shade_funcs [shade ] if isinstance (shade , str ) else shade
232+ shade_vals = scale * shade_func (plt_powers , axis = 0 ) \
233+ if isinstance (shade , str ) else scale * shade
234+ upper_shade = avg_powers + shade_vals
235+ lower_shade = avg_powers - shade_vals
223236
237+ # Plot +/- yshading around spectrum
224238 alpha = plot_kwargs .pop ('alpha' , 0.25 )
225- ax .fill_between (plt_freqs , lower_shade , upper_shade , alpha = alpha , ** plot_kwargs )
239+ ax .fill_between (plt_freqs , lower_shade , upper_shade ,
240+ alpha = alpha , color = color , ** plot_kwargs )
226241
227242 style_spectrum_plot (ax , log_freqs , log_powers )
0 commit comments