|
71 | 71 | from specparam.core.modutils import copy_doc_func_to_method |
72 | 72 | from specparam.core.utils import group_three, check_array_dim |
73 | 73 | from specparam.core.funcs import gaussian_function, get_ap_func, infer_ap_func |
| 74 | +from specparam.core.jacobians import jacobian_gauss |
74 | 75 | from specparam.core.errors import (FitError, NoModelError, DataError, |
75 | 76 | NoDataError, InconsistentDataError) |
76 | 77 | from specparam.core.strings import (gen_settings_str, gen_model_results_str, |
@@ -191,12 +192,17 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h |
191 | 192 | self._gauss_overlap_thresh = 0.75 |
192 | 193 | # Parameter bounds for center frequency when fitting gaussians, in terms of +/- std dev |
193 | 194 | self._cf_bound = 1.5 |
194 | | - # The maximum number of calls to the curve fitting function |
195 | | - self._maxfev = 5000 |
196 | 195 | # The error metric to calculate, post model fitting. See `_calc_error` for options |
197 | 196 | # Note: this is for checking error post fitting, not an objective function for fitting |
198 | 197 | self._error_metric = 'MAE' |
199 | 198 |
|
| 199 | + ## PRIVATE CURVE_FIT SETTINGS |
| 200 | + # The maximum number of calls to the curve fitting function |
| 201 | + self._maxfev = 5000 |
| 202 | + # The tolerance setting for curve fitting (see scipy.curve_fit - ftol / xtol / gtol) |
| 203 | + # Here reduce tolerance to speed fitting. Set value to 1e-8 to match curve_fit default |
| 204 | + self._tol = 0.00001 |
| 205 | + |
200 | 206 | ## RUN MODES |
201 | 207 | # Set default debug mode - controls if an error is raised if model fitting is unsuccessful |
202 | 208 | self._debug = False |
@@ -944,7 +950,9 @@ def _simple_ap_fit(self, freqs, power_spectrum): |
944 | 950 | warnings.simplefilter("ignore") |
945 | 951 | aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), |
946 | 952 | freqs, power_spectrum, p0=guess, |
947 | | - maxfev=self._maxfev, bounds=ap_bounds) |
| 953 | + maxfev=self._maxfev, bounds=ap_bounds, |
| 954 | + ftol=self._tol, xtol=self._tol, gtol=self._tol, |
| 955 | + check_finite=False) |
948 | 956 | except RuntimeError as excp: |
949 | 957 | error_msg = ("Model fitting failed due to not finding parameters in " |
950 | 958 | "the simple aperiodic component fit.") |
@@ -1001,7 +1009,9 @@ def _robust_ap_fit(self, freqs, power_spectrum): |
1001 | 1009 | warnings.simplefilter("ignore") |
1002 | 1010 | aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode), |
1003 | 1011 | freqs_ignore, spectrum_ignore, p0=popt, |
1004 | | - maxfev=self._maxfev, bounds=ap_bounds) |
| 1012 | + maxfev=self._maxfev, bounds=ap_bounds, |
| 1013 | + ftol=self._tol, xtol=self._tol, gtol=self._tol, |
| 1014 | + check_finite=False) |
1005 | 1015 | except RuntimeError as excp: |
1006 | 1016 | error_msg = ("Model fitting failed due to not finding " |
1007 | 1017 | "parameters in the robust aperiodic fit.") |
@@ -1147,7 +1157,9 @@ def _fit_peak_guess(self, guess): |
1147 | 1157 | # Fit the peaks |
1148 | 1158 | try: |
1149 | 1159 | gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat, |
1150 | | - p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds) |
| 1160 | + p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds, |
| 1161 | + ftol=self._tol, xtol=self._tol, gtol=self._tol, |
| 1162 | + check_finite=False, jac=jacobian_gauss) |
1151 | 1163 | except RuntimeError as excp: |
1152 | 1164 | error_msg = ("Model fitting failed due to not finding " |
1153 | 1165 | "parameters in the peak component fit.") |
|
0 commit comments