Skip to content

Commit 083cc08

Browse files
authored
Merge pull request #176 from ryanhammonds/plts
[ENH] Plot style managment
2 parents 7c75d64 + f5e448e commit 083cc08

23 files changed

+571
-243
lines changed

fooof/objs/fit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
gen_issue_str, gen_width_warning_str)
7575

7676
from fooof.plts.fm import plot_fm
77-
from fooof.plts.style import style_spectrum_plot
7877
from fooof.utils.data import trim_spectrum
7978
from fooof.utils.params import compute_gauss_std
8079
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
@@ -633,12 +632,13 @@ def get_results(self):
633632
@copy_doc_func_to_method(plot_fm)
634633
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,
635634
add_legend=True, save_fig=False, file_name=None, file_path=None,
636-
ax=None, plot_style=style_spectrum_plot,
637-
data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None):
635+
ax=None, data_kwargs=None, model_kwargs=None,
636+
aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):
638637

639-
plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend,
640-
save_fig, file_name, file_path, ax, plot_style,
641-
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs)
638+
plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log,
639+
add_legend=add_legend, save_fig=save_fig, file_name=file_name,
640+
file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs,
641+
aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs)
642642

643643

644644
@copy_doc_func_to_method(save_report_fm)

fooof/objs/group.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,9 @@ def get_params(self, name, col=None):
396396

397397

398398
@copy_doc_func_to_method(plot_fg)
399-
def plot(self, save_fig=False, file_name=None, file_path=None):
399+
def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs):
400400

401-
plot_fg(self, save_fig, file_name, file_path)
401+
plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs)
402402

403403

404404
@copy_doc_func_to_method(save_report_fg)

fooof/plts/annotate.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from fooof.core.funcs import gaussian_function
88
from fooof.core.modutils import safe_import, check_dependency
99
from fooof.sim.gen import gen_aperiodic
10-
from fooof.plts.utils import check_ax
10+
from fooof.plts.utils import check_ax, savefig
1111
from fooof.plts.spectra import plot_spectrum
1212
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
13-
from fooof.plts.style import check_n_style, style_spectrum_plot
13+
from fooof.plts.style import style_spectrum_plot
1414
from fooof.analysis.periodic import get_band_peak_fm
1515
from fooof.utils.params import compute_knee_frequency, compute_fwhm
1616

@@ -20,16 +20,15 @@
2020
###################################################################################################
2121
###################################################################################################
2222

23+
@savefig
2324
@check_dependency(plt, 'matplotlib')
24-
def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
25+
def plot_annotated_peak_search(fm):
2526
"""Plot a series of plots illustrating the peak search from a flattened spectrum.
2627
2728
Parameters
2829
----------
2930
fm : FOOOF
3031
FOOOF object, with model fit, data and settings available.
31-
plot_style : callable, optional, default: style_spectrum_plot
32-
A function to call to apply styling & aesthetics to the plots.
3332
"""
3433

3534
# Recalculate the initial aperiodic fit and flattened spectrum that
@@ -46,14 +45,12 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
4645
# This forces the creation of a new plotting axes per iteration
4746
ax = check_ax(None, PLT_FIGSIZES['spectral'])
4847

49-
plot_spectrum(fm.freqs, flatspec, ax=ax, plot_style=None,
50-
label='Flattened Spectrum', color=PLT_COLORS['data'], linewidth=2.5)
51-
plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs),
52-
ax=ax, plot_style=None, label='Relative Threshold',
53-
color='orange', linewidth=2.5, linestyle='dashed')
54-
plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs),
55-
ax=ax, plot_style=None, label='Absolute Threshold',
56-
color='red', linewidth=2.5, linestyle='dashed')
48+
plot_spectrum(fm.freqs, flatspec, ax=ax, linewidth=2.5,
49+
label='Flattened Spectrum', color=PLT_COLORS['data'])
50+
plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), ax=ax,
51+
label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed')
52+
plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs), ax=ax,
53+
label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed')
5754

5855
maxi = np.argmax(flatspec)
5956
ax.plot(fm.freqs[maxi], flatspec[maxi], '.',
@@ -65,18 +62,18 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
6562
if ind < fm.n_peaks_:
6663

6764
gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :])
68-
plot_spectrum(fm.freqs, gauss, ax=ax, plot_style=None,
69-
label='Gaussian Fit', color=PLT_COLORS['periodic'],
70-
linestyle=':', linewidth=3.0)
65+
plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit',
66+
color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0)
7167

7268
flatspec = flatspec - gauss
7369

74-
check_n_style(plot_style, ax, False, True)
70+
style_spectrum_plot(ax, False, True)
7571

7672

73+
@savefig
7774
@check_dependency(plt, 'matplotlib')
78-
def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True,
79-
ax=None, plot_style=style_spectrum_plot):
75+
def plot_annotated_model(fm, plt_log=False, annotate_peaks=True,
76+
annotate_aperiodic=True, ax=None):
8077
"""Plot a an annotated power spectrum and model, from a FOOOF object.
8178
8279
Parameters
@@ -91,8 +88,6 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
9188
Whether to annotate the aperiodic components of the model fit.
9289
ax : matplotlib.Axes, optional
9390
Figure axes upon which to plot.
94-
plot_style : callable, optional, default: style_spectrum_plot
95-
A function to call to apply styling & aesthetics to the plots.
9691
9792
Raises
9893
------
@@ -112,7 +107,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
112107

113108
# Create the baseline figure
114109
ax = check_ax(ax, PLT_FIGSIZES['spectral'])
115-
fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, plot_style=None,
110+
fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax,
116111
data_kwargs={'lw' : lw1, 'alpha' : 0.6},
117112
aperiodic_kwargs={'lw' : lw1, 'zorder' : 10},
118113
model_kwargs={'lw' : lw1, 'alpha' : 0.5},
@@ -219,7 +214,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
219214
color=PLT_COLORS['aperiodic'], fontsize=fontsize)
220215

221216
# Apply style to plot & tune grid styling
222-
check_n_style(plot_style, ax, plt_log, True)
217+
style_spectrum_plot(ax, plt_log, True)
223218
ax.grid(True, alpha=0.5)
224219

225220
# Add labels to plot in the legend

fooof/plts/aperiodic.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,23 @@
33
from itertools import cycle
44

55
import numpy as np
6+
import matplotlib.pyplot as plt
67

78
from fooof.sim.gen import gen_freqs, gen_aperiodic
89
from fooof.core.modutils import safe_import, check_dependency
910
from fooof.plts.settings import PLT_FIGSIZES
10-
from fooof.plts.style import check_n_style, style_param_plot
11-
from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs
11+
from fooof.plts.style import style_param_plot, style_plot
12+
from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs
1213

1314
plt = safe_import('.pyplot', 'matplotlib')
1415

1516
###################################################################################################
1617
###################################################################################################
1718

19+
@savefig
20+
@style_plot
1821
@check_dependency(plt, 'matplotlib')
19-
def plot_aperiodic_params(aps, colors=None, labels=None,
20-
ax=None, plot_style=style_param_plot, **plot_kwargs):
22+
def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs):
2123
"""Plot aperiodic parameters as dots representing offset and exponent value.
2224
2325
Parameters
@@ -30,38 +32,38 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
3032
Label(s) for plotted data, to be added in a legend.
3133
ax : matplotlib.Axes, optional
3234
Figure axes upon which to plot.
33-
plot_style : callable, optional, default: style_param_plot
34-
A function to call to apply styling & aesthetics to the plot.
3535
**plot_kwargs
36-
Keyword arguments to pass into the plot call.
36+
Keyword arguments to pass into the ``style_plot``.
3737
"""
3838

3939
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
4040

4141
if isinstance(aps, list):
42-
recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels,
43-
plot_style=plot_style, **plot_kwargs)
42+
recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels)
4443

4544
else:
4645

4746
# Unpack data: offset as x; exponent as y
4847
xs, ys = aps[:, 0], aps[:, -1]
4948
sizes = plot_kwargs.pop('s', 150)
5049

50+
# Create the plot
5151
plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7})
5252
ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs)
5353

5454
# Add axis labels
5555
ax.set_xlabel('Offset')
5656
ax.set_ylabel('Exponent')
5757

58-
check_n_style(plot_style, ax)
58+
style_param_plot(ax)
5959

6060

61+
@savefig
62+
@style_plot
6163
@check_dependency(plt, 'matplotlib')
6264
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
6365
log_freqs=False, colors=None, labels=None,
64-
ax=None, plot_style=style_param_plot, **plot_kwargs):
66+
ax=None, **plot_kwargs):
6567
"""Plot reconstructions of model aperiodic fits.
6668
6769
Parameters
@@ -80,10 +82,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
8082
Label(s) for plotted data, to be added in a legend.
8183
ax : matplotlib.Axes, optional
8284
Figure axes upon which to plot.
83-
plot_style : callable, optional, default: style_param_plot
84-
A function to call to apply styling & aesthetics to the plot.
8585
**plot_kwargs
86-
Keyword arguments to pass into the plot call.
86+
Keyword arguments to pass into the ``style_plot``.
8787
"""
8888

8989
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
@@ -93,11 +93,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
9393
if not colors:
9494
colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
9595

96-
recursive_plot(aps, plot_function=plot_aperiodic_fits, ax=ax,
97-
freq_range=tuple(freq_range),
98-
control_offset=control_offset,
99-
log_freqs=log_freqs, colors=colors, labels=labels,
100-
plot_style=plot_style, **plot_kwargs)
96+
recursive_plot(aps, plot_aperiodic_fits, ax=ax, freq_range=tuple(freq_range),
97+
control_offset=control_offset, log_freqs=log_freqs, colors=colors,
98+
labels=labels, **plot_kwargs)
10199
else:
102100

103101
freqs = gen_freqs(freq_range, 0.1)
@@ -118,17 +116,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
118116
# Recreate & plot the aperiodic component from parameters
119117
ap_vals = gen_aperiodic(freqs, ap_params)
120118

121-
plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25})
122-
ax.plot(plt_freqs, ap_vals, color=colors, **plot_kwargs)
119+
ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)
123120

124121
# Collect a running average across components
125122
avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0)
126123

127124
# Plot the average component
128125
avg = avg_vals / aps.shape[0]
129126
avg_color = 'black' if not colors else colors
130-
ax.plot(plt_freqs, avg, linewidth=plot_kwargs.get('linewidth')*3,
131-
color=avg_color, label=labels)
127+
ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels)
132128

133129
# Add axis labels
134130
ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency')
@@ -137,5 +133,4 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
137133
# Set plot limit
138134
ax.set_xlim(np.log10(freq_range) if log_freqs else freq_range)
139135

140-
# Apply plot style
141-
check_n_style(plot_style, ax)
136+
style_param_plot(ax)

fooof/plts/error.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@
55
from fooof.core.modutils import safe_import, check_dependency
66
from fooof.plts.spectra import plot_spectrum
77
from fooof.plts.settings import PLT_FIGSIZES
8-
from fooof.plts.style import check_n_style, style_spectrum_plot
9-
from fooof.plts.utils import check_ax
8+
from fooof.plts.style import style_spectrum_plot, style_plot
9+
from fooof.plts.utils import check_ax, savefig
1010

1111
plt = safe_import('.pyplot', 'matplotlib')
1212

1313
###################################################################################################
1414
###################################################################################################
1515

16+
@savefig
17+
@style_plot
1618
@check_dependency(plt, 'matplotlib')
17-
def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
18-
ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
19+
def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **plot_kwargs):
1920
"""Plot frequency by frequency error values.
2021
2122
Parameters
@@ -31,17 +32,15 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
3132
Whether to plot the frequency axis in log spacing.
3233
ax : matplotlib.Axes, optional
3334
Figure axes upon which to plot.
34-
plot_style : callable, optional, default: style_spectrum_plot
35-
A function to call to apply styling & aesthetics to the plot.
3635
**plot_kwargs
37-
Keyword arguments to be passed to `plot_spectra` or to the plot call.
36+
Keyword arguments to pass into the ``style_plot``.
3837
"""
3938

4039
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
4140

4241
plt_freqs = np.log10(freqs) if log_freqs else freqs
4342

44-
plot_spectrum(plt_freqs, error, plot_style=None, ax=ax, linewidth=3, **plot_kwargs)
43+
plot_spectrum(plt_freqs, error, ax=ax, linewidth=3)
4544

4645
if np.any(shade):
4746
ax.fill_between(plt_freqs, error-shade, error+shade, alpha=0.25)
@@ -51,5 +50,5 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
5150
ax.set_ylim([0, ymax])
5251
ax.set_xlim(plt_freqs.min(), plt_freqs.max())
5352

54-
check_n_style(plot_style, ax, log_freqs, True)
53+
style_spectrum_plot(ax, log_freqs, True)
5554
ax.set_ylabel('Absolute Error')

0 commit comments

Comments
 (0)