|
71 | 71 | from fooof.core.modutils import copy_doc_func_to_method |
72 | 72 | from fooof.core.utils import group_three, check_array_dim |
73 | 73 | from fooof.core.funcs import gaussian_function, get_ap_func, infer_ap_func |
| 74 | +from fooof.core.jacobians import jacobian_gauss |
74 | 75 | from fooof.core.errors import (FitError, NoModelError, DataError, |
75 | 76 | NoDataError, InconsistentDataError) |
76 | 77 | from fooof.core.strings import (gen_settings_str, gen_results_fm_str, |
@@ -192,12 +193,17 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h |
192 | 193 | self._gauss_overlap_thresh = 0.75 |
193 | 194 | # Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev |
194 | 195 | self._cf_bound = 1.5 |
195 | | - # The maximum number of calls to the curve fitting function |
196 | | - self._maxfev = 5000 |
197 | 196 | # The error metric to calculate, post model fitting. See `_calc_error` for options |
198 | 197 | # Note: this is for checking error post fitting, not an objective function for fitting |
199 | 198 | self._error_metric = 'MAE' |
200 | 199 |
|
| 200 | + ## PRIVATE CURVE_FIT SETTINGS |
| 201 | + # The maximum number of calls to the curve fitting function |
| 202 | + self._maxfev = 5000 |
| 203 | + # The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol) |
| 204 | + # Here reduce tolerance to speed fitting. Set value to 1e-8 to match curve_fit default |
| 205 | + self._tol = 0.00001 |
| 206 | + |
201 | 207 | ## RUN MODES |
202 | 208 | # Set default debug mode - controls if an error is raised if model fitting is unsuccessful |
203 | 209 | self._debug = False |
@@ -946,7 +952,9 @@ def _simple_ap_fit(self, freqs, power_spectrum): |
946 | 952 | warnings.simplefilter("ignore") |
947 | 953 | aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), |
948 | 954 | freqs, power_spectrum, p0=guess, |
949 | | - maxfev=self._maxfev, bounds=ap_bounds) |
| 955 | + maxfev=self._maxfev, bounds=ap_bounds, |
| 956 | + ftol=self._tol, xtol=self._tol, gtol=self._tol, |
| 957 | + check_finite=False) |
950 | 958 | except RuntimeError as excp: |
951 | 959 | error_msg = ("Model fitting failed due to not finding parameters in " |
952 | 960 | "the simple aperiodic component fit.") |
@@ -1003,7 +1011,9 @@ def _robust_ap_fit(self, freqs, power_spectrum): |
1003 | 1011 | warnings.simplefilter("ignore") |
1004 | 1012 | aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), |
1005 | 1013 | freqs_ignore, spectrum_ignore, p0=popt, |
1006 | | - maxfev=self._maxfev, bounds=ap_bounds) |
| 1014 | + maxfev=self._maxfev, bounds=ap_bounds, |
| 1015 | + ftol=self._tol, xtol=self._tol, gtol=self._tol, |
| 1016 | + check_finite=False) |
1007 | 1017 | except RuntimeError as excp: |
1008 | 1018 | error_msg = ("Model fitting failed due to not finding " |
1009 | 1019 | "parameters in the robust aperiodic fit.") |
@@ -1149,7 +1159,9 @@ def _fit_peak_guess(self, guess): |
1149 | 1159 | # Fit the peaks |
1150 | 1160 | try: |
1151 | 1161 | gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, |
1152 | | - p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) |
| 1162 | + p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds, |
| 1163 | + ftol=self._tol, xtol=self._tol, gtol=self._tol, |
| 1164 | + check_finite=False, jac=jacobian_gauss) |
1153 | 1165 | except RuntimeError as excp: |
1154 | 1166 | error_msg = ("Model fitting failed due to not finding " |
1155 | 1167 | "parameters in the peak component fit.") |
|
0 commit comments