33from itertools import cycle
44
55import numpy as np
6+ import matplotlib .pyplot as plt
67
78from fooof .sim .gen import gen_freqs , gen_aperiodic
89from fooof .core .modutils import safe_import , check_dependency
910from fooof .plts .settings import PLT_FIGSIZES
10- from fooof .plts .style import check_n_style , style_param_plot
11- from fooof .plts .utils import check_ax , recursive_plot , check_plot_kwargs
11+ from fooof .plts .style import style_param_plot , style_plot
12+ from fooof .plts .utils import check_ax , recursive_plot , savefig , check_plot_kwargs
1213
1314plt = safe_import ('.pyplot' , 'matplotlib' )
1415
1516###################################################################################################
1617###################################################################################################
1718
19+ @savefig
20+ @style_plot
1821@check_dependency (plt , 'matplotlib' )
19- def plot_aperiodic_params (aps , colors = None , labels = None ,
20- ax = None , plot_style = style_param_plot , ** plot_kwargs ):
22+ def plot_aperiodic_params (aps , colors = None , labels = None , ax = None , ** plot_kwargs ):
2123 """Plot aperiodic parameters as dots representing offset and exponent value.
2224
2325 Parameters
@@ -30,38 +32,38 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
3032 Label(s) for plotted data, to be added in a legend.
3133 ax : matplotlib.Axes, optional
3234 Figure axes upon which to plot.
33- plot_style : callable, optional, default: style_param_plot
34- A function to call to apply styling & aesthetics to the plot.
3535 **plot_kwargs
36- Keyword arguments to pass into the plot call .
36+ Keyword arguments to pass into the ``style_plot`` .
3737 """
3838
3939 ax = check_ax (ax , plot_kwargs .pop ('figsize' , PLT_FIGSIZES ['params' ]))
4040
4141 if isinstance (aps , list ):
42- recursive_plot (aps , plot_aperiodic_params , ax , colors = colors , labels = labels ,
43- plot_style = plot_style , ** plot_kwargs )
42+ recursive_plot (aps , plot_aperiodic_params , ax , colors = colors , labels = labels )
4443
4544 else :
4645
4746 # Unpack data: offset as x; exponent as y
4847 xs , ys = aps [:, 0 ], aps [:, - 1 ]
4948 sizes = plot_kwargs .pop ('s' , 150 )
5049
50+ # Create the plot
5151 plot_kwargs = check_plot_kwargs (plot_kwargs , {'alpha' : 0.7 })
5252 ax .scatter (xs , ys , sizes , c = colors , label = labels , ** plot_kwargs )
5353
5454 # Add axis labels
5555 ax .set_xlabel ('Offset' )
5656 ax .set_ylabel ('Exponent' )
5757
58- check_n_style ( plot_style , ax )
58+ style_param_plot ( ax )
5959
6060
61+ @savefig
62+ @style_plot
6163@check_dependency (plt , 'matplotlib' )
6264def plot_aperiodic_fits (aps , freq_range , control_offset = False ,
6365 log_freqs = False , colors = None , labels = None ,
64- ax = None , plot_style = style_param_plot , ** plot_kwargs ):
66+ ax = None , ** plot_kwargs ):
6567 """Plot reconstructions of model aperiodic fits.
6668
6769 Parameters
@@ -80,10 +82,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
8082 Label(s) for plotted data, to be added in a legend.
8183 ax : matplotlib.Axes, optional
8284 Figure axes upon which to plot.
83- plot_style : callable, optional, default: style_param_plot
84- A function to call to apply styling & aesthetics to the plot.
8585 **plot_kwargs
86- Keyword arguments to pass into the plot call .
86+ Keyword arguments to pass into the ``style_plot`` .
8787 """
8888
8989 ax = check_ax (ax , plot_kwargs .pop ('figsize' , PLT_FIGSIZES ['params' ]))
@@ -93,11 +93,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
9393 if not colors :
9494 colors = cycle (plt .rcParams ['axes.prop_cycle' ].by_key ()['color' ])
9595
96- recursive_plot (aps , plot_function = plot_aperiodic_fits , ax = ax ,
97- freq_range = tuple (freq_range ),
98- control_offset = control_offset ,
99- log_freqs = log_freqs , colors = colors , labels = labels ,
100- plot_style = plot_style , ** plot_kwargs )
96+ recursive_plot (aps , plot_aperiodic_fits , ax = ax , freq_range = tuple (freq_range ),
97+ control_offset = control_offset , log_freqs = log_freqs , colors = colors ,
98+ labels = labels , ** plot_kwargs )
10199 else :
102100
103101 freqs = gen_freqs (freq_range , 0.1 )
@@ -118,17 +116,15 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
118116 # Recreate & plot the aperiodic component from parameters
119117 ap_vals = gen_aperiodic (freqs , ap_params )
120118
121- plot_kwargs = check_plot_kwargs (plot_kwargs , {'alpha' : 0.35 , 'linewidth' : 1.25 })
122- ax .plot (plt_freqs , ap_vals , color = colors , ** plot_kwargs )
119+ ax .plot (plt_freqs , ap_vals , color = colors , alpha = 0.35 , linewidth = 1.25 )
123120
124121 # Collect a running average across components
125122 avg_vals = np .nansum (np .vstack ([avg_vals , ap_vals ]), axis = 0 )
126123
127124 # Plot the average component
128125 avg = avg_vals / aps .shape [0 ]
129126 avg_color = 'black' if not colors else colors
130- ax .plot (plt_freqs , avg , linewidth = plot_kwargs .get ('linewidth' )* 3 ,
131- color = avg_color , label = labels )
127+ ax .plot (plt_freqs , avg , linewidth = 3.75 , color = avg_color , label = labels )
132128
133129 # Add axis labels
134130 ax .set_xlabel ('log(Frequency)' if log_freqs else 'Frequency' )
@@ -137,5 +133,4 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
137133 # Set plot limit
138134 ax .set_xlim (np .log10 (freq_range ) if log_freqs else freq_range )
139135
140- # Apply plot style
141- check_n_style (plot_style , ax )
136+ style_param_plot (ax )
0 commit comments