Skip to content

Commit a5e8e3e

Browse files
committed
Update tests for updates, and related fixes
1 parent 3af502b commit a5e8e3e

File tree

5 files changed

+67
-41
lines changed

5 files changed

+67
-41
lines changed

fooof/core/info.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,18 @@ def get_obj_desc():
99
Mapping of FOOOF object attributes, and what kind of data they are.
1010
"""
1111

12-
attributes = {'results' : ['aperiodic_params_', 'peak_params_', 'error_',
13-
'r_squared_', '_gaussian_params'],
14-
'settings' : ['peak_width_limits', 'max_n_peaks', 'min_peak_amplitude',
15-
'peak_threshold', 'aperiodic_mode'],
12+
attributes = {'results' : ['aperiodic_params_', 'peak_params_',
13+
'r_squared_', 'error_',
14+
'_gaussian_params'],
15+
'settings' : ['peak_width_limits', 'max_n_peaks',
16+
'min_peak_amplitude', 'peak_threshold',
17+
'aperiodic_mode'],
1618
'data' : ['power_spectrum', 'freq_range', 'freq_res'],
1719
'data_info' : ['freq_range', 'freq_res'],
1820
'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_',
19-
'peak_params_', '_gaussian_params']}
21+
'peak_params_', '_gaussian_params'],
22+
'model_components' : ['_spectrum_flat', '_spectrum_peak_rm',
23+
'_ap_fit', '_peak_fit']}
2024

2125
return attributes
2226

fooof/group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def get_fooof(self, ind, regenerate=False):
318318
fm.add_data(self.freqs, np.power(10, self.power_spectra[ind]))
319319
# If no power spectrum data available, copy over data information & regenerate freqs
320320
else:
321-
fm.add_data_info(self.get_data_info)
321+
fm.add_data_info(self.get_data_info())
322322

323323
# Add results for specified power spectrum, regenerating full fit if requested
324324
fm.add_results(self.group_results[ind])

fooof/tests/test_fit.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,12 @@ def test_adds():
128128
for setting in get_obj_desc()['settings']:
129129
assert getattr(tfm, setting) == getattr(fooof_settings, setting)
130130

131+
# Test adding data info
132+
fooof_data_info = FOOOFDataInfo([3, 40], 0.5)
133+
tfm.add_data_info(fooof_data_info)
134+
for data_info in get_obj_desc()['data_info']:
135+
assert getattr(tfm, data_info) == getattr(fooof_data_info, data_info)
136+
131137
# Test adding results
132138
fooof_results = FOOOFResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25])
133139
tfm.add_results(fooof_results)
@@ -180,12 +186,12 @@ def test_fooof_resets():
180186
tfm._reset_data_results()
181187
tfm._reset_internal_settings()
182188

183-
assert tfm.freqs is None and tfm.freq_range is None and tfm.freq_res is None \
184-
and tfm.power_spectrum is None and tfm.fooofed_spectrum_ is None and tfm._spectrum_flat is None \
185-
and tfm._spectrum_peak_rm is None and tfm._ap_fit is None and tfm._peak_fit is None
189+
desc = get_obj_desc()
186190

187-
# assert np.all(np.isnan(tfm.aperiodic_params_)) and np.all(np.isnan(tfm.peak_params_)) \
188-
# and np.all(np.isnan(tfm.r_squared_)) and np.all(np.isnan(tfm.error_)) and np.all(np.isnan(tfm._gaussian_params))
191+
for data in ['data', 'results', 'model_components']:
192+
for field in desc[data]:
193+
assert getattr(tfm, field) == None
194+
assert tfm.freqs == None and tfm.fooofed_spectrum_ == None
189195

190196
def test_fooof_report(skip_if_no_mpl):
191197
"""Check that running the top level model method runs."""

fooof/tests/test_funcs.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,44 +20,51 @@ def test_combine_fooofs(tfm, tfg):
2020
tfg2 = tfg.copy(); tfg3 = tfg.copy()
2121

2222
# Check combining 2 FOOOFs
23-
fg1 = combine_fooofs([tfm, tfm2])
24-
assert fg1
25-
assert len(fg1) == 2
26-
assert compare_info([fg1, tfm], 'settings')
27-
assert fg1.group_results[0] == tfm.get_results()
28-
assert fg1.group_results[-1] == tfm2.get_results()
23+
nfg1 = combine_fooofs([tfm, tfm2])
24+
assert nfg1
25+
assert len(nfg1) == 2
26+
assert compare_info([nfg1, tfm], 'settings')
27+
assert nfg1.group_results[0] == tfm.get_results()
28+
assert nfg1.group_results[-1] == tfm2.get_results()
2929

3030
# Check combining 3 FOOOFs
31-
fg2 = combine_fooofs([tfm, tfm2, tfm3])
32-
assert fg2
33-
assert len(fg2) == 3
34-
assert compare_info([fg2, tfm], 'settings')
35-
assert fg2.group_results[0] == tfm.get_results()
36-
assert fg2.group_results[-1] == tfm3.get_results()
31+
nfg2 = combine_fooofs([tfm, tfm2, tfm3])
32+
assert nfg2
33+
assert len(nfg2) == 3
34+
assert compare_info([nfg2, tfm], 'settings')
35+
assert nfg2.group_results[0] == tfm.get_results()
36+
assert nfg2.group_results[-1] == tfm3.get_results()
3737

3838
# Check combining 2 FOOOFGroups
39-
nfg1 = combine_fooofs([tfg, tfg2])
40-
assert nfg1
41-
assert len(nfg1) == len(tfg) + len(tfg2)
42-
assert compare_info([nfg1, tfg, tfg2], 'settings')
43-
assert nfg1.group_results[0] == tfg.group_results[0]
44-
assert nfg1.group_results[-1] == tfg2.group_results[-1]
39+
nfg3 = combine_fooofs([tfg, tfg2])
40+
assert nfg3
41+
assert len(nfg3) == len(tfg) + len(tfg2)
42+
assert compare_info([nfg3, tfg, tfg2], 'settings')
43+
assert nfg3.group_results[0] == tfg.group_results[0]
44+
assert nfg3.group_results[-1] == tfg2.group_results[-1]
4545

4646
# Check combining 3 FOOOFGroups
47-
nfg2 = combine_fooofs([tfg, tfg2, tfg3])
48-
assert nfg2
49-
assert len(nfg2) == len(tfg) + len(tfg2) + len(tfg3)
50-
assert compare_info([nfg2, tfg, tfg2, tfg3], 'settings')
51-
assert nfg2.group_results[0] == tfg.group_results[0]
52-
assert nfg2.group_results[-1] == tfg3.group_results[-1]
47+
nfg4 = combine_fooofs([tfg, tfg2, tfg3])
48+
assert nfg4
49+
assert len(nfg4) == len(tfg) + len(tfg2) + len(tfg3)
50+
assert compare_info([nfg4, tfg, tfg2, tfg3], 'settings')
51+
assert nfg4.group_results[0] == tfg.group_results[0]
52+
assert nfg4.group_results[-1] == tfg3.group_results[-1]
5353

5454
# Check combining a mixture of FOOOF & FOOOFGroup
55-
mfg3 = combine_fooofs([tfg, tfm, tfg2, tfm2])
56-
assert mfg3
57-
assert len(mfg3) == len(tfg) + 1 + len(tfg2) + 1
58-
assert compare_info([tfg, tfm, tfg2, tfm2], 'settings')
59-
assert mfg3.group_results[0] == tfg.group_results[0]
60-
assert mfg3.group_results[-1] == tfm2.get_results()
55+
nfg5 = combine_fooofs([tfg, tfm, tfg2, tfm2])
56+
assert nfg5
57+
assert len(nfg5) == len(tfg) + 1 + len(tfg2) + 1
58+
assert compare_info([nfg5, tfg, tfm, tfg2, tfm2], 'settings')
59+
assert nfg5.group_results[0] == tfg.group_results[0]
60+
assert nfg5.group_results[-1] == tfm2.get_results()
61+
62+
# Check combining objects with no data
63+
tfm2._reset_data_results(False, True, True)
64+
tfg2._reset_data_results(False, True, True, True)
65+
nfg6 = combine_fooofs([tfm2, tfg2])
66+
assert len(nfg6) == 1 + len(tfg2)
67+
assert nfg6.power_spectra == None
6168

6269
def test_combine_errors(tfm, tfg):
6370

fooof/tests/test_group.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,12 @@ def test_fg_get_fooof(tfg):
144144
# Check that regenerated model is created
145145
for result in desc['results']:
146146
assert np.all(getattr(tfm1, result))
147+
148+
# Test when object has no data (clear a copy of tfg)
149+
new_tfg = tfg.copy()
150+
new_tfg._reset_data_results(False, True, True, True)
151+
tfm2 = new_tfg.get_fooof(0, True)
152+
assert tfm2
153+
# Check that data info is copied over properly
154+
for data_info in desc['data_info']:
155+
assert getattr(tfm2, data_info)

0 commit comments

Comments
 (0)