6767from fooof .core .reports import save_report_fm
6868from fooof .core .modutils import copy_doc_func_to_method
6969from fooof .core .utils import group_three , check_array_dim
70- from fooof .core .funcs import gaussian_function , get_ap_func , infer_ap_func
70+ from fooof .core .funcs import get_pe_func , get_ap_func , infer_ap_func
7171from fooof .core .errors import (FitError , NoModelError , DataError ,
7272 NoDataError , InconsistentDataError )
7373from fooof .core .strings import (gen_settings_str , gen_results_fm_str ,
@@ -154,8 +154,9 @@ class FOOOF():
154154 """
155155 # pylint: disable=attribute-defined-outside-init
156156
157- def __init__ (self , peak_width_limits = (0.5 , 12.0 ), max_n_peaks = np .inf , min_peak_height = 0.0 ,
158- peak_threshold = 2.0 , aperiodic_mode = 'fixed' , verbose = True ):
157+ def __init__ (self , peak_width_limits = (0.5 , 12.0 ), max_n_peaks = np .inf ,
158+ min_peak_height = 0.0 , peak_threshold = 2.0 , aperiodic_mode = 'fixed' ,
159+ periodic_mode = 'gaussian' , verbose = True ):
159160 """Initialize object with desired settings."""
160161
161162 # Set input settings
@@ -164,6 +165,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
164165 self .min_peak_height = min_peak_height
165166 self .peak_threshold = peak_threshold
166167 self .aperiodic_mode = aperiodic_mode
168+ self .periodic_mode = periodic_mode
167169 self .verbose = verbose
168170
169171 ## PRIVATE SETTINGS
@@ -439,6 +441,9 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
439441 if self .verbose :
440442 self ._check_width_limits ()
441443
444+ # Determine the aperiodic and periodic fit funcs
445+ self ._set_fit_funcs ()
446+
442447 # In rare cases, the model fails to fit, and so uses try / except
443448 try :
444449
@@ -715,6 +720,11 @@ def set_check_data_mode(self, check_data):
715720
716721 self ._check_data = check_data
717722
723+ def _set_fit_funcs (self ):
724+ """Set the requested aperiodic and periodic fit functions."""
725+
726+ self ._pe_func = get_pe_func (self .periodic_mode )
727+ self ._ap_func = get_ap_func (self .aperiodic_mode )
718728
719729 def _check_width_limits (self ):
720730 """Check and warn about peak width limits / frequency resolution interaction."""
@@ -762,8 +772,7 @@ def _simple_ap_fit(self, freqs, power_spectrum):
762772 try :
763773 with warnings .catch_warnings ():
764774 warnings .simplefilter ("ignore" )
765- aperiodic_params , _ = curve_fit (get_ap_func (self .aperiodic_mode ),
766- freqs , power_spectrum , p0 = guess ,
775+ aperiodic_params , _ = curve_fit (self ._ap_func , freqs , power_spectrum , p0 = guess ,
767776 maxfev = self ._maxfev , bounds = ap_bounds )
768777 except RuntimeError :
769778 raise FitError ("Model fitting failed due to not finding parameters in "
@@ -818,9 +827,8 @@ def _robust_ap_fit(self, freqs, power_spectrum):
818827 try :
819828 with warnings .catch_warnings ():
820829 warnings .simplefilter ("ignore" )
821- aperiodic_params , _ = curve_fit (get_ap_func (self .aperiodic_mode ),
822- freqs_ignore , spectrum_ignore , p0 = popt ,
823- maxfev = self ._maxfev , bounds = ap_bounds )
830+ aperiodic_params , _ = curve_fit (self ._ap_func , freqs_ignore , spectrum_ignore ,
831+ p0 = popt , maxfev = self ._maxfev , bounds = ap_bounds )
824832 except RuntimeError :
825833 raise FitError ("Model fitting failed due to not finding "
826834 "parameters in the robust aperiodic fit." )
@@ -904,7 +912,7 @@ def _fit_peaks(self, flat_iter):
904912
905913 # Collect guess parameters and subtract this guess gaussian from the data
906914 guess = np .vstack ((guess , (guess_freq , guess_height , guess_std )))
907- peak_gauss = gaussian_function (self .freqs , guess_freq , guess_height , guess_std )
915+ peak_gauss = self . _pe_func (self .freqs , guess_freq , guess_height , guess_std )
908916 flat_iter = flat_iter - peak_gauss
909917
910918 # Check peaks based on edges, and on overlap, dropping any that violate requirements
@@ -963,7 +971,7 @@ def _fit_peak_guess(self, guess):
963971
964972 # Fit the peaks
965973 try :
966- gaussian_params , _ = curve_fit (gaussian_function , self .freqs , self ._spectrum_flat ,
974+ gaussian_params , _ = curve_fit (self . _pe_func , self .freqs , self ._spectrum_flat ,
967975 p0 = guess , maxfev = self ._maxfev , bounds = gaus_param_bounds )
968976 except RuntimeError :
969977 raise FitError ("Model fitting failed due to not finding "
0 commit comments