Skip to content

Commit 1c3b030

Browse files
authored
Merge pull request #181 from fooof-tools/adds
[ENH] Update & fix (re-)add data and settings
2 parents e5cb7ce + 0c85181 commit 1c3b030

File tree

18 files changed

+90
-54
lines changed

18 files changed

+90
-54
lines changed

fooof/analysis/error.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def compute_pointwise_error_fg(fg, plot_errors=True, return_errors=False, **plt_
102102

103103

104104
def compute_pointwise_error(model, data):
105-
"""Calculate pointwise error between original data and a model fit of that data.
105+
"""Calculate point-wise error between original data and a model fit of that data.
106106
107107
Parameters
108108
----------

fooof/bands/bands.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A data oject for managing band definitions."""
1+
"""A data object for managing band definitions."""
22

33
from collections import OrderedDict
44

@@ -60,7 +60,7 @@ def __len__(self):
6060
return self.n_bands
6161

6262
def __iter__(self):
63-
"""Define iteratation as stepping across each band."""
63+
"""Define iteration as stepping across each band."""
6464

6565
for label, band_definition in self.bands.items():
6666
yield (label, band_definition)

fooof/core/info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def get_ap_indices(aperiodic_mode):
7070
Returns
7171
-------
7272
indices : dict
73-
Mapping of the column labels and indices for the aperiodc parameters.
73+
Mapping of the column labels and indices for the aperiodic parameters.
7474
"""
7575

7676
if aperiodic_mode == 'fixed':

fooof/core/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def load_json(file_name, file_path):
198198

199199

200200
def load_jsonlines(file_name, file_path):
201-
"""Load a jsonlines file, yielding data line by line.
201+
"""Load a json-lines file, yielding data line by line.
202202
203203
Parameters
204204
----------

fooof/core/modutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def docs_append_to_section(docstring, section, add):
7979
8080
Parameters
8181
----------
82-
ds : str
82+
docstring : str
8383
Docstring to update.
8484
section : str
8585
Name of the section within the docstring to add to.

fooof/core/reports.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
REPORT_FONT = {'family': 'monospace',
1717
'weight': 'normal',
1818
'size': 16}
19+
SAVE_FORMAT = 'pdf'
1920

2021
###################################################################################################
2122
###################################################################################################
@@ -61,7 +62,7 @@ def save_report_fm(fm, file_name, file_path=None, plt_log=False):
6162
ax2.set_yticks([])
6263

6364
# Save out the report
64-
plt.savefig(fpath(file_path, fname(file_name, 'pdf')))
65+
plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT)))
6566
plt.close()
6667

6768

@@ -104,5 +105,5 @@ def save_report_fg(fg, file_name, file_path=None):
104105
plot_fg_peak_cens(fg, ax3)
105106

106107
# Save out the report
107-
plt.savefig(fpath(file_path, fname(file_name, 'pdf')))
108+
plt.savefig(fpath(file_path, fname(file_name, SAVE_FORMAT)))
108109
plt.close()

fooof/core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def check_inds(inds):
210210
This function works only on indices defined for 1 dimension.
211211
"""
212212

213-
# Typcasting: if a single int, convert to an array
213+
# Typecasting: if a single int, convert to an array
214214
if isinstance(inds, int):
215215
inds = np.array([inds])
216216
# Typecasting: if a list or range, convert to an array

fooof/objs/fit.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,10 @@ def _reset_internal_settings(self):
238238
# Bandwidth limits are given in 2-sided peak bandwidth
239239
# Convert to gaussian std parameter limits
240240
self._gauss_std_limits = tuple([bwl / 2 for bwl in self.peak_width_limits])
241-
# Bounds for aperiodic fitting. Drops bounds on knee parameter if not set to fit knee
242-
self._ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \
243-
else tuple(bound[0::2] for bound in self._ap_bounds)
244241

245242
# Otherwise, assume settings are unknown (have been cleared) and set to None
246243
else:
247244
self._gauss_std_limits = None
248-
self._ap_bounds = None
249245

250246

251247
def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False):
@@ -286,7 +282,7 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res
286282
self._peak_fit = None
287283

288284

289-
def add_data(self, freqs, power_spectrum, freq_range=None):
285+
def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
290286
"""Add data (frequencies, and power spectrum values) to the current object.
291287
292288
Parameters
@@ -298,17 +294,22 @@ def add_data(self, freqs, power_spectrum, freq_range=None):
298294
freq_range : list of [float, float], optional
299295
Frequency range to restrict power spectrum to.
300296
If not provided, keeps the entire range.
297+
clear_results : bool, optional, default: True
298+
Whether to clear prior results, if any are present in the object.
299+
This should only be set to False if data for the current results are being re-added.
301300
302301
Notes
303302
-----
304303
If called on an object with existing data and/or results
305304
they will be cleared by this method call.
306305
"""
307306

308-
# If any data is already present, then clear data & results
307+
# If any data is already present, then clear previous data
308+
# Also clear results, if present, unless indicated not to
309309
# This is to ensure object consistency of all data & results
310-
if np.any(self.freqs):
311-
self._reset_data_results(True, True, True)
310+
self._reset_data_results(clear_freqs=self.has_data,
311+
clear_spectrum=self.has_data,
312+
clear_results=self.has_model and clear_results)
312313

313314
self.freqs, self.power_spectrum, self.freq_range, self.freq_res = \
314315
self._prepare_data(freqs, power_spectrum, freq_range, 1, self.verbose)
@@ -717,6 +718,10 @@ def _simple_ap_fit(self, freqs, power_spectrum):
717718
np.log10(self.freqs[-1]) - np.log10(self.freqs[0]))
718719
if not self._ap_guess[2] else self._ap_guess[2]]
719720

721+
# Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee
722+
ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \
723+
else tuple(bound[0::2] for bound in self._ap_bounds)
724+
720725
# Collect together guess parameters
721726
guess = np.array([off_guess + kne_guess + exp_guess])
722727

@@ -729,7 +734,7 @@ def _simple_ap_fit(self, freqs, power_spectrum):
729734
warnings.simplefilter("ignore")
730735
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
731736
freqs, power_spectrum, p0=guess,
732-
maxfev=self._maxfev, bounds=self._ap_bounds)
737+
maxfev=self._maxfev, bounds=ap_bounds)
733738
except RuntimeError:
734739
raise FitError("Model fitting failed due to not finding parameters in "
735740
"the simple aperiodic component fit.")
@@ -774,14 +779,18 @@ def _robust_ap_fit(self, freqs, power_spectrum):
774779
freqs_ignore = freqs[perc_mask]
775780
spectrum_ignore = power_spectrum[perc_mask]
776781

782+
# Get bounds for aperiodic fitting, dropping knee bound if not set to fit knee
783+
ap_bounds = self._ap_bounds if self.aperiodic_mode == 'knee' \
784+
else tuple(bound[0::2] for bound in self._ap_bounds)
785+
777786
# Second aperiodic fit - using results of first fit as guess parameters
778787
# See note in _simple_ap_fit about warnings
779788
try:
780789
with warnings.catch_warnings():
781790
warnings.simplefilter("ignore")
782791
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
783792
freqs_ignore, spectrum_ignore, p0=popt,
784-
maxfev=self._maxfev, bounds=self._ap_bounds)
793+
maxfev=self._maxfev, bounds=ap_bounds)
785794
except RuntimeError:
786795
raise FitError("Model fitting failed due to not finding "
787796
"parameters in the robust aperiodic fit.")
@@ -851,7 +860,7 @@ def _fit_peaks(self, flat_iter):
851860
guess_std = compute_gauss_std(fwhm)
852861

853862
except ValueError:
854-
# This procedure can fail (extremely rarely), if both le & ri ind's end up as None
863+
# This procedure can fail (very rarely), if both left & right inds end up as None
855864
# In this case, default the guess to the average of the peak width limits
856865
guess_std = np.mean(self.peak_width_limits)
857866

@@ -1027,21 +1036,21 @@ def _drop_peak_overlap(self, guess):
10271036
Notes
10281037
-----
10291038
For any gaussians with an overlap that crosses the threshold,
1030-
the lowest height guess guassian is dropped.
1039+
the lowest height guess Gaussian is dropped.
10311040
"""
10321041

1033-
# Sort the peak guesses by increasing frequency, so adjacenent peaks can
1034-
# be compared from right to left.
1042+
# Sort the peak guesses by increasing frequency
1043+
# This is so adjacent peaks can be compared from right to left
10351044
guess = sorted(guess, key=lambda x: float(x[0]))
10361045

10371046
# Calculate standard deviation bounds for checking amount of overlap
1038-
# The bounds are the gaussian frequncy +/- gaussian standard deviation
1047+
# The bounds are the gaussian frequency +/- gaussian standard deviation
10391048
bounds = [[peak[0] - peak[2] * self._gauss_overlap_thresh,
10401049
peak[0] + peak[2] * self._gauss_overlap_thresh] for peak in guess]
10411050

10421051
# Loop through peak bounds, comparing current bound to that of next peak
10431052
# If the left peak's upper bound extends pass the right peaks lower bound,
1044-
# Then drop the guassian with the lower height.
1053+
# then drop the Gaussian with the lower height
10451054
drop_inds = []
10461055
for ind, b_0 in enumerate(bounds[:-1]):
10471056
b_1 = bounds[ind + 1]

fooof/objs/group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def _check_width_limits(self):
557557
"""Check and warn about bandwidth limits / frequency resolution interaction."""
558558

559559
# Only check & warn on first power spectrum
560-
# This is to avoid spamming stdout for every spectrum in the group
560+
# This is to avoid spamming standard output for every spectrum in the group
561561
if self.power_spectra[0, 0] == self.power_spectrum[0]:
562562
super()._check_width_limits()
563563

fooof/objs/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def combine_fooofs(fooofs):
138138
--------
139139
Combine FOOOF objects together (where `fm1`, `fm2` & `fm3` are assumed to be defined and fit):
140140
141-
>>> fg = combine_fooofs([fm1, fm2, f3]) # doctest:+SKIP
141+
>>> fg = combine_fooofs([fm1, fm2, fm3]) # doctest:+SKIP
142142
143143
Combine FOOOFGroup objects together (where `fg1` & `fg2` are assumed to be defined and fit):
144144

0 commit comments

Comments
 (0)