Skip to content

Commit f3d1ccb

Browse files
committed
add check_data, to control whether to fail on nans
1 parent 2a66519 commit f3d1ccb

File tree

2 files changed

+45
-17
lines changed

2 files changed

+45
-17
lines changed

fooof/objs/fit.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,15 @@
3939
The maximum number of calls to the curve fitting function.
4040
_error_metric : str
4141
The error metric to use for post-hoc measures of model fit error.
42+
43+
Run Modes
44+
---------
4245
_debug : bool
4346
Whether the object is set in debug mode.
4447
This should be controlled by using the `set_debug_mode` method.
48+
_check_data : bool
49+
Whether to check added data for NaN or Inf values, and fail out if present.
50+
This should be controlled by using the `set_check_data_mode` method.
4551
4652
Code Notes
4753
----------
@@ -184,12 +190,16 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
184190
# The maximum number of calls to the curve fitting function
185191
self._maxfev = 5000
186192
# The error metric to calculate, post model fitting. See `_calc_error` for options
187-
# Note: this is used to check error post-hoc, not an objective function for fitting models
193+
# Note: this is for checking error post fitting, not an objective function for fitting
188194
self._error_metric = 'MAE'
189-
# Set whether in debug mode, in which an error is raised if a model fit fails
195+
196+
## RUN MODES
197+
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
190198
self._debug = False
199+
# Set default check data mode - controls if an error is raised if NaN / Inf data are added
200+
self._check_data = True
191201

192-
# Set internal settings, based on inputs, & initialize data & results attributes
202+
# Set internal settings, based on inputs, and initialize data & results attributes
193203
self._reset_internal_settings()
194204
self._reset_data_results(True, True, True)
195205

@@ -311,7 +321,7 @@ def add_data(self, freqs, power_spectrum, freq_range=None):
311321
self._reset_data_results(True, True, True)
312322

313323
self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
314-
self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose)
324+
self._prepare_data(freqs, power_spectrum, freq_range, 1)
315325

316326

317327
def add_settings(self, fooof_settings):
@@ -431,6 +441,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
431441
# In rare cases, the model fails to fit, and so uses try / except
432442
try:
433443

444+
# If not set to fail on NaN or Inf data at add time, check data here
445+
# This serves as a catch all for curve_fits which will fail given NaN or Inf
446+
# Because FitError's are by default caught, this allows fitting to continue
447+
if not self._check_data:
448+
if not np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
449+
raise FitError("There are NaN or Inf values in the data, "
450+
"which preclude model fitting.")
451+
434452
# Fit the aperiodic component
435453
self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
436454
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
@@ -674,7 +692,7 @@ def copy(self):
674692

675693

676694
def set_debug_mode(self, debug):
677-
"""Set whether debug mode, wherein an error is raised if fitting is unsuccessful.
695+
"""Set debug mode, which controls if an error is raised if model fitting is unsuccessful.
678696
679697
Parameters
680698
----------
@@ -685,6 +703,18 @@ def set_debug_mode(self, debug):
685703
self._debug = debug
686704

687705

706+
def set_check_data(self, check_data):
707+
"""Set check data mode, which controls if an error is raised if NaN or Inf data are added.
708+
709+
Parameters
710+
----------
711+
check_data : bool
712+
Whether to run in check data mode.
713+
"""
714+
715+
self._check_data = check_data
716+
717+
688718
def _check_width_limits(self):
689719
"""Check and warn about peak width limits / frequency resolution interaction."""
690720

@@ -1101,8 +1131,7 @@ def _calc_error(self, metric=None):
11011131
raise ValueError(msg)
11021132

11031133

1104-
@staticmethod
1105-
def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True):
1134+
def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
11061135
"""Prepare input data for adding to current object.
11071136
11081137
Parameters
@@ -1116,8 +1145,6 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
11161145
Frequency range to restrict power spectrum to. If None, keeps the entire range.
11171146
spectra_dim : int, optional, default: 1
11181147
Dimensionality that the power spectra should have.
1119-
verbose : bool, optional
1120-
Whether to be verbose in printing out warnings.
11211148
11221149
Returns
11231150
-------
@@ -1172,7 +1199,7 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
11721199
# Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
11731200
if freqs[0] == 0.0:
11741201
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()])
1175-
if verbose:
1202+
if self.verbose:
11761203
print("\nFOOOF WARNING: Skipping frequency == 0, "
11771204
"as this causes a problem with fitting.")
11781205

@@ -1183,12 +1210,13 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
11831210
# Log power values
11841211
power_spectrum = np.log10(power_spectrum)
11851212

1186-
# Check if there are any infs / nans, and raise an error if so
1187-
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
1188-
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
1189-
"This will cause the fitting to fail. "
1190-
"One reason this can happen is if inputs are already logged. "
1191-
"Inputs data should be in linear spacing, not log.")
1213+
if self._check_data:
1214+
# Check if there are any infs / nans, and raise an error if so
1215+
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
1216+
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
1217+
"This will cause the fitting to fail. "
1218+
"One reason this can happen is if inputs are already logged. "
1219+
"Inputs data should be in linear spacing, not log.")
11921220

11931221
return freqs, power_spectrum, freq_range, freq_res
11941222

fooof/objs/group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def add_data(self, freqs, power_spectra, freq_range=None):
222222
self._reset_group_results()
223223

224224
self.freqs, self.power_spectra, self.freq_range, self.freq_res = \
225-
self._prepare_data(freqs, power_spectra, freq_range, 2, self.verbose)
225+
self._prepare_data(freqs, power_spectra, freq_range, 2)
226226

227227

228228
def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):

0 commit comments

Comments
 (0)