Skip to content

Commit 37d6cd8

Browse files
committed
make grid an updateable argument in PSD plots
1 parent d63aae0 commit 37d6cd8

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

fooof/plts/spectra.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, freq_r
5050
Additional plot related keyword arguments.
5151
"""
5252

53+
# Create the plot & collect plot kwargs of interest
5354
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
54-
55-
# Create the plot
5655
plot_kwargs = check_plot_kwargs(plot_kwargs, {'linewidth' : 2.0})
56+
grid = plot_kwargs.pop('grid', True)
5757

5858
# Check for frequency range input, and log if x-axis is in log space
5959
if freq_range is not None:
@@ -82,7 +82,7 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, freq_r
8282

8383
ax.set_xlim(freq_range)
8484

85-
style_spectrum_plot(ax, log_freqs, log_powers)
85+
style_spectrum_plot(ax, log_freqs, log_powers, grid)
8686

8787

8888
# Alias `plot_spectrum` to `plot_spectra` for backwards compatibility
@@ -127,7 +127,8 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',
127127
add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False))
128128

129129
style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False),
130-
plot_kwargs.get('log_powers', False))
130+
plot_kwargs.get('log_powers', False),
131+
plot_kwargs.get('grid', True))
131132

132133

133134
# Alias `plot_spectrum_shading` to `plot_spectra_shading` for backwards compatibility
@@ -172,6 +173,7 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale
172173
raise ValueError('Power spectra must be 2d if shade is not given.')
173174

174175
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
176+
grid = plot_kwargs.pop('grid', True)
175177

176178
# Set plot data & labels, logging if requested
177179
plt_freqs = np.log10(freqs) if log_freqs else freqs
@@ -208,4 +210,4 @@ def plot_spectra_yshade(freqs, power_spectra, shade='std', average='mean', scale
208210
ax.fill_between(plt_freqs, lower_shade, upper_shade,
209211
alpha=alpha, color=color, **plot_kwargs)
210212

211-
style_spectrum_plot(ax, log_freqs, log_powers)
213+
style_spectrum_plot(ax, log_freqs, log_powers, grid)

fooof/plts/style.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
###################################################################################################
1313
###################################################################################################
1414

15-
def style_spectrum_plot(ax, log_freqs, log_powers):
15+
def style_spectrum_plot(ax, log_freqs, log_powers, grid=True):
1616
"""Apply style and aesthetics to a power spectrum plot.
1717
1818
Parameters
@@ -23,6 +23,8 @@ def style_spectrum_plot(ax, log_freqs, log_powers):
2323
Whether the frequency axis is plotted in log space.
2424
log_powers : bool
2525
Whether the power axis is plotted in log space.
26+
grid : bool, optional, default: True
27+
Whether to add grid lines to the plot.
2628
"""
2729

2830
# Get labels, based on log status
@@ -33,7 +35,7 @@ def style_spectrum_plot(ax, log_freqs, log_powers):
3335
ax.set_xlabel(xlabel, fontsize=20)
3436
ax.set_ylabel(ylabel, fontsize=20)
3537
ax.tick_params(axis='both', which='major', labelsize=16)
36-
ax.grid(True)
38+
ax.grid(grid)
3739

3840
# If labels were provided, add a legend
3941
if ax.get_legend_handles_labels()[0]:

0 commit comments

Comments
 (0)