Skip to content

Commit 7a520c4

Browse files
authored
Merge pull request #182 from fooof-tools/nans
[ENH] Add `check_data` mode to deal with running with/without nan values
2 parents 1c3b030 + 815ae18 commit 7a520c4

File tree

4 files changed

+73
-21
lines changed

4 files changed

+73
-21
lines changed

fooof/objs/fit.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,15 @@
3939
The maximum number of calls to the curve fitting function.
4040
_error_metric : str
4141
The error metric to use for post-hoc measures of model fit error.
42+
43+
Run Modes
44+
---------
4245
_debug : bool
4346
Whether the object is set in debug mode.
4447
This should be controlled by using the `set_debug_mode` method.
48+
_check_data : bool
49+
Whether to check added data for NaN or Inf values, and fail out if present.
50+
This should be controlled by using the `set_check_data_mode` method.
4551
4652
Code Notes
4753
----------
@@ -184,12 +190,16 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
184190
# The maximum number of calls to the curve fitting function
185191
self._maxfev = 5000
186192
# The error metric to calculate, post model fitting. See `_calc_error` for options
187-
# Note: this is used to check error post-hoc, not an objective function for fitting models
193+
# Note: this is for checking error post fitting, not an objective function for fitting
188194
self._error_metric = 'MAE'
189-
# Set whether in debug mode, in which an error is raised if a model fit fails
195+
196+
## RUN MODES
197+
# Set default debug mode - controls if an error is raised if model fitting is unsuccessful
190198
self._debug = False
199+
# Set default check data mode - controls if an error is raised if NaN / Inf data are added
200+
self._check_data = True
191201

192-
# Set internal settings, based on inputs, & initialize data & results attributes
202+
# Set internal settings, based on inputs, and initialize data & results attributes
193203
self._reset_internal_settings()
194204
self._reset_data_results(True, True, True)
195205

@@ -312,7 +322,7 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
312322
clear_results=self.has_model and clear_results)
313323

314324
self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
315-
self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose)
325+
self._prepare_data(freqs, power_spectrum, freq_range, 1)
316326

317327

318328
def add_settings(self, fooof_settings):
@@ -432,6 +442,14 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
432442
# In rare cases, the model fails to fit, and so uses try / except
433443
try:
434444

445+
# If not set to fail on NaN or Inf data at add time, check data here
446+
# This serves as a catch all for curve_fits which will fail given NaN or Inf
447+
# Because FitError's are by default caught, this allows fitting to continue
448+
if not self._check_data:
449+
if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)):
450+
raise FitError("Model fitting was skipped because there are NaN or Inf "
451+
"values in the data, which preclude model fitting.")
452+
435453
# Fit the aperiodic component
436454
self.aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum)
437455
self._ap_fit = gen_aperiodic(self.freqs, self.aperiodic_params_)
@@ -675,7 +693,7 @@ def copy(self):
675693

676694

677695
def set_debug_mode(self, debug):
678-
"""Set whether debug mode, wherein an error is raised if fitting is unsuccessful.
696+
"""Set debug mode, which controls if an error is raised if model fitting is unsuccessful.
679697
680698
Parameters
681699
----------
@@ -686,6 +704,18 @@ def set_debug_mode(self, debug):
686704
self._debug = debug
687705

688706

707+
def set_check_data_mode(self, check_data):
708+
"""Set check data mode, which controls if an error is raised if NaN or Inf data are added.
709+
710+
Parameters
711+
----------
712+
check_data : bool
713+
Whether to run in check data mode.
714+
"""
715+
716+
self._check_data = check_data
717+
718+
689719
def _check_width_limits(self):
690720
"""Check and warn about peak width limits / frequency resolution interaction."""
691721

@@ -795,7 +825,8 @@ def _robust_ap_fit(self, freqs, power_spectrum):
795825
raise FitError("Model fitting failed due to not finding "
796826
"parameters in the robust aperiodic fit.")
797827
except TypeError:
798-
raise FitError("Model fitting failed due to sub-sampling in the robust aperiodic fit.")
828+
raise FitError("Model fitting failed due to sub-sampling "
829+
"in the robust aperiodic fit.")
799830

800831
return aperiodic_params
801832

@@ -1110,8 +1141,7 @@ def _calc_error(self, metric=None):
11101141
raise ValueError(msg)
11111142

11121143

1113-
@staticmethod
1114-
def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True):
1144+
def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1):
11151145
"""Prepare input data for adding to current object.
11161146
11171147
Parameters
@@ -1125,8 +1155,6 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
11251155
Frequency range to restrict power spectrum to. If None, keeps the entire range.
11261156
spectra_dim : int, optional, default: 1
11271157
Dimensionality that the power spectra should have.
1128-
verbose : bool, optional
1129-
Whether to be verbose in printing out warnings.
11301158
11311159
Returns
11321160
-------
@@ -1181,7 +1209,7 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
11811209
# Aperiodic fit gets an inf if freq of 0 is included, which leads to an error
11821210
if freqs[0] == 0.0:
11831211
freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()])
1184-
if verbose:
1212+
if self.verbose:
11851213
print("\nFOOOF WARNING: Skipping frequency == 0, "
11861214
"as this causes a problem with fitting.")
11871215

@@ -1192,12 +1220,13 @@ def _prepare_data(freqs, power_spectrum, freq_range, spectra_dim=1, verbose=True
11921220
# Log power values
11931221
power_spectrum = np.log10(power_spectrum)
11941222

1195-
# Check if there are any infs / nans, and raise an error if so
1196-
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
1197-
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
1198-
"This will cause the fitting to fail. "
1199-
"One reason this can happen is if inputs are already logged. "
1200-
"Inputs data should be in linear spacing, not log.")
1223+
if self._check_data:
1224+
# Check if there are any infs / nans, and raise an error if so
1225+
if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)):
1226+
raise DataError("The input power spectra data, after logging, contains NaNs or Infs. "
1227+
"This will cause the fitting to fail. "
1228+
"One reason this can happen is if inputs are already logged. "
1229+
"Inputs data should be in linear spacing, not log.")
12011230

12021231
return freqs, power_spectrum, freq_range, freq_res
12031232

fooof/objs/group.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def add_data(self, freqs, power_spectra, freq_range=None):
222222
self._reset_group_results()
223223

224224
self.freqs, self.power_spectra, self.freq_range, self.freq_res = \
225-
self._prepare_data(freqs, power_spectra, freq_range, 2, self.verbose)
225+
self._prepare_data(freqs, power_spectra, freq_range, 2)
226226

227227

228228
def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None):
@@ -476,8 +476,9 @@ def get_fooof(self, ind, regenerate=True):
476476
The FOOOFResults data loaded into a FOOOF object.
477477
"""
478478

479-
# Initialize a FOOOF object, with same settings as current FOOOFGroup
479+
# Initialize a FOOOF object, with same settings & check data mode as current FOOOFGroup
480480
fm = FOOOF(*self.get_settings(), verbose=self.verbose)
481+
fm.set_check_data_mode(self._check_data)
481482

482483
# Add data for specified single power spectrum, if available
483484
# The power spectrum is inverted back to linear, as it is re-logged when added to FOOOF

fooof/objs/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def combine_fooofs(fooofs):
178178
if len(fg) == temp_power_spectra.shape[0]:
179179
fg.power_spectra = temp_power_spectra
180180

181+
# Set the check data mode, as True if any of the inputs have it on, False otherwise
182+
fg.set_check_data_mode(any([getattr(f_obj, '_check_data') for f_obj in fooofs]))
183+
181184
# Add data information information
182185
fg.add_meta_data(fooofs[0].get_meta_data())
183186

fooof/tests/objs/test_fit.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from fooof.core.items import OBJ_DESC
1313
from fooof.core.errors import FitError
1414
from fooof.core.utils import group_three
15-
from fooof.sim import gen_power_spectrum
15+
from fooof.sim import gen_freqs, gen_power_spectrum
1616
from fooof.data import FOOOFSettings, FOOOFMetaData, FOOOFResults
1717
from fooof.core.errors import DataError, NoDataError, InconsistentDataError
1818

@@ -396,7 +396,7 @@ def raise_runtime_error(*args, **kwargs):
396396
assert np.all(np.isnan(getattr(tfm, result)))
397397

398398
def test_fooof_debug():
399-
"""Test FOOOF fit failure in debug mode."""
399+
"""Test FOOOF in debug mode, including with fit failures."""
400400

401401
tfm = FOOOF(verbose=False)
402402
tfm._maxfev = 5
@@ -406,3 +406,22 @@ def test_fooof_debug():
406406

407407
with raises(FitError):
408408
tfm.fit(*gen_power_spectrum([3, 50], [50, 2], [10, 0.5, 2, 20, 0.3, 4]))
409+
410+
def test_fooof_check_data():
411+
"""Test FOOOF in with check data mode turned off, including with NaN data."""
412+
413+
tfm = FOOOF(verbose=False)
414+
415+
tfm.set_check_data_mode(False)
416+
assert tfm._check_data is False
417+
418+
# Add data, with check data turned off
419+
# In check data mode, adding data with NaN should run
420+
freqs = gen_freqs([3, 50], 0.5)
421+
powers = np.ones_like(freqs) * np.nan
422+
tfm.add_data(freqs, powers)
423+
assert tfm.has_data
424+
425+
# Model fitting should execute, but return a null model fit, given the NaNs, without failing
426+
tfm.fit()
427+
assert not tfm.has_model

0 commit comments

Comments
 (0)