55This file contains functions for plotting power spectra, that take in data directly.
66"""
77
8+ from inspect import isfunction
89from itertools import repeat , cycle
910
1011import numpy as np
@@ -209,7 +210,7 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale
209210 Keyword arguments to be passed to `plot_spectra` or to the plot call.
210211 """
211212
212- if isinstance (shade , str ) and power_spectra .ndim != 2 :
213+ if ( isinstance (shade , str ) or isfunction ( shade ) ) and power_spectra .ndim != 2 :
213214 raise ValueError ('Power spectra must be 2d if shade is not given.' )
214215
215216 ax = check_ax (ax , plot_kwargs .pop ('figsize' , PLT_FIGSIZES ['spectral' ]))
@@ -220,17 +221,27 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale
220221
221222 # Organize mean spectrum to plot
222223 avg_funcs = {'mean' : np .mean , 'median' : np .median }
223- avg_func = avg_funcs [average ] if isinstance (average , str ) else average
224- avg_powers = avg_func (plt_powers , axis = 0 ) if plt_powers .ndim == 2 else plt_powers
224+
225+ if isinstance (average , str ) and plt_powers .ndim == 2 :
226+ avg_powers = avg_funcs [average ](plt_powers , axis = 0 )
227+ elif isfunction (average ) and plt_powers .ndim == 2 :
228+ avg_powers = average (plt_powers )
229+ else :
230+ avg_powers = plt_powers
225231
226232 # Plot average power spectrum
227233 ax .plot (plt_freqs , avg_powers , linewidth = 2.0 , color = color , label = label )
228234
229235 # Organize shading to plot
230236 shade_funcs = {'std' : np .std , 'sem' : sem }
231- shade_func = shade_funcs [shade ] if isinstance (shade , str ) else shade
232- shade_vals = scale * shade_func (plt_powers , axis = 0 ) \
233- if isinstance (shade , str ) else scale * shade
237+
238+ if isinstance (shade , str ):
239+ shade_vals = scale * shade_funcs [shade ](plt_powers , axis = 0 )
240+ elif isfunction (shade ):
241+ shade_vals = scale * shade (plt_powers )
242+ else :
243+ shade_vals = scale * shade
244+
234245 upper_shade = avg_powers + shade_vals
235246 lower_shade = avg_powers - shade_vals
236247
0 commit comments