Skip to content

Commit 5d089f9

Browse files
committed
Update get & comp info, add FOOOFDataInfo to consolidate appraoch to data
1 parent c2f32a0 commit 5d089f9

File tree

9 files changed

+175
-111
lines changed

9 files changed

+175
-111
lines changed

fooof/core/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def save_fm(fm, file_name, file_path=None, append=False,
8989
# Set and select which variables to keep. Use a set to drop any potential overlap
9090
# Note that results also saves frequency information to be able to recreate freq vector
9191
attributes = get_obj_desc()
92-
keep = set((attributes['results'] + attributes['freq_info'] if save_results else []) + \
92+
keep = set((attributes['results'] + attributes['data_info'] if save_results else []) + \
9393
(attributes['settings'] if save_settings else []) + \
9494
(attributes['data'] if save_data else []))
9595
obj_dict = dict_select_keys(obj_dict, keep)

fooof/core/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def get_obj_desc():
114114
'settings' : ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude',
115115
'peak_threshold', 'aperiodic_mode'],
116116
'data' : ['power_spectrum', 'freq_range', 'freq_res'],
117-
'freq_info' : ['freq_range', 'freq_res'],
117+
'data_info' : ['freq_range', 'freq_res'],
118118
'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_',
119119
'peak_params_', '_gaussian_params']}
120120

fooof/data.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@
2626
"""
2727

2828

29+
FOOOFDataInfo = namedtuple('FOOOFDataInfo', ['freq_range', 'freq_res'])
30+
31+
FOOOFDataInfo.__doc__ = """\
32+
Data related information for a FOOOF object.
33+
34+
Attributes
35+
----------
36+
freq_range : list of [float, float]
37+
Frequency range of the power spectrum, as [lowest_freq, highest_freq].
38+
freq_res : float
39+
Frequency resolution of the power spectrum.
40+
"""
41+
42+
2943
FOOOFResults = namedtuple('FOOOFResults', ['aperiodic_params', 'peak_params',
3044
'r_squared', 'error', 'gaussian_params'])
3145
FOOOFResults.__doc__ = """\

fooof/fit.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
from fooof.plts.fm import plot_fm
4848
from fooof.utils import trim_spectrum
49-
from fooof.data import FOOOFResults, FOOOFSettings
49+
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFDataInfo
5050
from fooof.synth.gen import gen_freqs, gen_aperiodic, gen_peaks
5151

5252
###################################################################################################
@@ -249,11 +249,8 @@ def add_settings(self, fooof_settings):
249249
An object containing the settings for a FOOOF model.
250250
"""
251251

252-
self.aperiodic_mode = fooof_settings.aperiodic_mode
253-
self.peak_width_limits = fooof_settings.peak_width_limits
254-
self.max_n_peaks = fooof_settings.max_n_peaks
255-
self.min_peak_amplitude = fooof_settings.min_peak_amplitude
256-
self.peak_threshold = fooof_settings.peak_threshold
252+
for setting in get_obj_desc()['settings']:
253+
setattr(self, setting, getattr(fooof_settings, setting))
257254

258255
self._check_loaded_settings(fooof_settings._asdict())
259256

@@ -421,31 +418,40 @@ def print_report_issue(concise=False):
421418
print(gen_issue_str(concise))
422419

423420

424-
def get_results(self):
425-
"""Return model fit parameters and goodness of fit metrics.
421+
def get_settings(self):
422+
"""Return user defined settings of the FOOOF object.
426423
427424
Returns
428425
-------
429-
FOOOFResults
430-
Object containing the FOOOF model fit results from the current FOOOF object.
426+
FOOOFSettings
427+
Object containing the settings from the current FOOOF object.
431428
"""
432429

433-
return FOOOFResults(self.aperiodic_params_, self.peak_params_,
434-
self.r_squared_, self.error_, self._gaussian_params)
430+
return FOOOFSettings(**{key : getattr(self, key) for key in get_obj_desc()['settings']})
435431

436432

437-
def get_settings(self):
438-
"""Return user defined settings of the FOOOF object.
433+
def get_data_info(self):
434+
"""Return data information from the FOOOF object.
439435
440436
Returns
441437
-------
442-
FOOOFSettings
443-
Object containing the settings from the current FOOOF object.
438+
FOOOFDataInfo
439+
Object containing information about the data from the current FOOOF object.
440+
"""
441+
442+
return FOOOFDataInfo(**{key : getattr(self, key) for key in get_obj_desc()['data_info']})
443+
444+
445+
def get_results(self):
446+
"""Return model fit parameters and goodness of fit metrics.
447+
448+
Returns
449+
-------
450+
FOOOFResults
451+
Object containing the FOOOF model fit results from the current FOOOF object.
444452
"""
445453

446-
return FOOOFSettings(self.peak_width_limits, self.max_n_peaks,
447-
self.min_peak_amplitude, self.peak_threshold,
448-
self.aperiodic_mode)
454+
return FOOOFResults(**{key.strip('_') : getattr(self, key) for key in get_obj_desc()['results']})
449455

450456

451457
@copy_doc_func_to_method(plot_fm)

fooof/funcs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from fooof import FOOOFGroup
66
from fooof.synth.gen import gen_freqs
7-
from fooof.utils import get_settings, get_obj_desc, compare_settings, compare_data_info
7+
from fooof.utils import get_obj_desc, compare_info
88

99
###################################################################################################
1010
###################################################################################################
@@ -24,12 +24,12 @@ def combine_fooofs(fooofs):
2424
"""
2525

2626
# Compare settings
27-
if not compare_settings(fooofs) or not compare_data_info(fooofs):
27+
if not compare_info(fooofs, 'settings') or not compare_info(fooofs, 'data_info'):
2828
raise ValueError("These objects have incompatible settings or data," \
2929
"and so cannot be combined.")
3030

3131
# Initialize FOOOFGroup object, with settings derived from input objects
32-
fg = FOOOFGroup(**get_settings(fooofs[0]), verbose=fooofs[0].verbose)
32+
fg = FOOOFGroup(*fooofs[0].get_settings(), verbose=fooofs[0].verbose)
3333
fg.power_spectra = np.empty([0, len(fooofs[0].freqs)])
3434

3535
# Add FOOOF results from each FOOOF object to group
@@ -44,7 +44,7 @@ def combine_fooofs(fooofs):
4444
fg.power_spectra = np.vstack([fg.power_spectra, f_obj.power_spectrum])
4545

4646
# Add data information information
47-
for data_info in get_obj_desc()['freq_info']:
47+
for data_info in get_obj_desc()['data_info']:
4848
setattr(fg, data_info, getattr(fooofs[0], data_info))
4949
fg.freqs = gen_freqs(fg.freq_range, fg.freq_res)
5050

fooof/tests/test_fit.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pkg_resources as pkg
1313

1414
from fooof import FOOOF
15-
from fooof.data import FOOOFSettings, FOOOFResults
15+
from fooof.data import FOOOFSettings, FOOOFDataInfo, FOOOFResults
1616
from fooof.synth import gen_power_spectrum
1717
from fooof.core.utils import group_three, get_obj_desc
1818

@@ -136,11 +136,13 @@ def test_adds():
136136
def test_gets(tfm):
137137
"""Tests methods that return FOOOF data objects.
138138
139-
Checks: get_settings, get_results
139+
Checks: get_settings, get_data_info, get_results
140140
"""
141141

142142
settings = tfm.get_settings()
143143
assert isinstance(settings, FOOOFSettings)
144+
data_info = tfm.get_data_info()
145+
assert isinstance(data_info, FOOOFDataInfo)
144146
results = tfm.get_results()
145147
assert isinstance(results, FOOOFResults)
146148

@@ -152,10 +154,11 @@ def test_copy():
152154

153155
assert tfm != ntfm
154156

155-
def test_fooof_prints_get(tfm):
156-
"""Test methods that print, return results (alias and pass through methods).
157+
def test_fooof_prints(tfm):
158+
"""Test methods that print (alias and pass through methods).
157159
158-
Checks: print_settings, print_results, get_results, get_settings."""
160+
Checks: print_settings, print_results.
161+
"""
159162

160163
tfm.print_settings()
161164
tfm.print_results()

fooof/tests/test_funcs.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from fooof.utils import compare_settings
7+
from fooof.utils import compare_info
88
from fooof.group import FOOOFGroup
99
from fooof.synth import gen_group_power_spectra
1010

@@ -23,39 +23,39 @@ def test_combine_fooofs(tfm, tfg):
2323
fg1 = combine_fooofs([tfm, tfm2])
2424
assert fg1
2525
assert len(fg1) == 2
26-
assert compare_settings([fg1, tfm])
26+
assert compare_info([fg1, tfm], 'settings')
2727
assert fg1.group_results[0] == tfm.get_results()
2828
assert fg1.group_results[-1] == tfm2.get_results()
2929

3030
# Check combining 3 FOOOFs
3131
fg2 = combine_fooofs([tfm, tfm2, tfm3])
3232
assert fg2
3333
assert len(fg2) == 3
34-
assert compare_settings([fg2, tfm])
34+
assert compare_info([fg2, tfm], 'settings')
3535
assert fg2.group_results[0] == tfm.get_results()
3636
assert fg2.group_results[-1] == tfm3.get_results()
3737

3838
# Check combining 2 FOOOFGroups
3939
nfg1 = combine_fooofs([tfg, tfg2])
4040
assert nfg1
4141
assert len(nfg1) == len(tfg) + len(tfg2)
42-
assert compare_settings([nfg1, tfg, tfg2])
42+
assert compare_info([nfg1, tfg, tfg2], 'settings')
4343
assert nfg1.group_results[0] == tfg.group_results[0]
4444
assert nfg1.group_results[-1] == tfg2.group_results[-1]
4545

4646
# Check combining 3 FOOOFGroups
4747
nfg2 = combine_fooofs([tfg, tfg2, tfg3])
4848
assert nfg2
4949
assert len(nfg2) == len(tfg) + len(tfg2) + len(tfg3)
50-
assert compare_settings([nfg2, tfg, tfg2, tfg3])
50+
assert compare_info([nfg2, tfg, tfg2, tfg3], 'settings')
5151
assert nfg2.group_results[0] == tfg.group_results[0]
5252
assert nfg2.group_results[-1] == tfg3.group_results[-1]
5353

5454
# Check combining a mixture of FOOOF & FOOOFGroup
5555
mfg3 = combine_fooofs([tfg, tfm, tfg2, tfm2])
5656
assert mfg3
5757
assert len(mfg3) == len(tfg) + 1 + len(tfg2) + 1
58-
assert compare_settings([tfg, tfm, tfg2, tfm2])
58+
assert compare_info([tfg, tfm, tfg2, tfm2], 'settings')
5959
assert mfg3.group_results[0] == tfg.group_results[0]
6060
assert mfg3.group_results[-1] == tfm2.get_results()
6161

fooof/tests/test_utils.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,35 +17,23 @@ def test_trim_spectrum():
1717
assert np.array_equal(f_out, np.array([2., 3., 4.]))
1818
assert np.array_equal(p_out, np.array([3., 4., 5.]))
1919

20-
def test_get_settings(tfm, tfg):
20+
def test_get_info(tfm, tfg):
2121

2222
for f_obj in [tfm, tfg]:
23-
assert get_settings(f_obj)
23+
assert get_info(f_obj, 'settings')
24+
assert get_info(f_obj, 'data_info')
2425

25-
def test_get_data_info(tfm, tfg):
26+
def test_compare_info(tfm, tfg):
2627

2728
for f_obj in [tfm, tfg]:
28-
assert get_data_info(f_obj)
2929

30-
def test_compare_settings(tfm, tfg):
31-
32-
for f_obj in [tfm, tfg]:
3330
f_obj2 = f_obj.copy()
3431

35-
assert compare_settings([f_obj, f_obj2])
36-
32+
assert compare_info([f_obj, f_obj2], 'settings')
3733
f_obj2.peak_width_limits = [2, 4]
3834
f_obj2._reset_internal_settings()
35+
assert not compare_info([f_obj, f_obj2], 'settings')
3936

40-
assert not compare_settings([f_obj, f_obj2])
41-
42-
def test_compare_data_info(tfm, tfg):
43-
44-
for f_obj in [tfm, tfg]:
45-
f_obj2 = f_obj.copy()
46-
47-
assert compare_data_info([f_obj, f_obj2])
48-
37+
assert compare_info([f_obj, f_obj2], 'data_info')
4938
f_obj2.freq_range = [5, 25]
50-
51-
assert not compare_data_info([f_obj, f_obj2])
39+
assert not compare_info([f_obj, f_obj2], 'data_info')

0 commit comments

Comments
 (0)