Skip to content

Commit c4ed7df

Browse files
committed
Add function to average across a FOOOFgroup
1 parent 34aedb1 commit c4ed7df

File tree

2 files changed

+36
-14
lines changed

2 files changed

+36
-14
lines changed

fooof/funcs.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,48 @@
33
import numpy as np
44

55
from fooof import FOOOF, FOOOFGroup
6-
from fooof.synth.gen import gen_freqs
6+
from fooof.data import FOOOFResults
77
from fooof.utils import compare_info
8+
from fooof.synth.gen import gen_freqs
9+
from fooof.analysis import get_band_peaks_fg
810

911
###################################################################################################
1012
###################################################################################################
1113

12-
def average_fg(fg, bands, avg='mean'):
13-
"""Average across a FOOOFGroup object."""
14+
def average_fg(fg, bands, avg_method='mean'):
15+
"""Average across a FOOOFGroup object.
1416
15-
if avg == 'mean':
16-
avg_func = np.nanmean
17-
elif avg == 'median':
18-
avg_func = np.nanmedian
17+
Parameters
18+
----------
19+
fg : FOOOFGroup
20+
A FOOOFGroup object with data to average across.
21+
bands : Bands
22+
Bands object that defines the frequency bands to collapse peaks across.
23+
avg : {'mean', 'median'}
24+
Averaging function to use.
1925
20-
ap_params = avg_func(fg.get_all_data('aperiodic_params'), 0)
26+
Returns
27+
-------
28+
fm : FOOOF
29+
FOOOF object containing the average results from the FOOOFGroup input.
30+
"""
2131

22-
peak_params, gaussian_params = np.empty([0, 3]), np.empty([0, 3])
32+
if avg_method not in ['mean', 'median']:
33+
raise ValueError('Requested average method not understood.')
34+
if not len(fg):
35+
raise ValueError('Input FOOOFGroup has no fit results - can not proceed.')
2336

24-
for label, band in bands:
37+
if avg_method == 'mean':
38+
avg_func = np.nanmean
39+
elif avg_method == 'median':
40+
avg_func = np.nanmedian
2541

26-
peak_params = np.vstack([peak_params,
27-
avg_func(get_band_peak_group(fg.get_all_data('peak_params'), band, len(fg)), 0)])
42+
ap_params = avg_func(fg.get_all_data('aperiodic_params'), 0)
2843

29-
gaussian_params = np.vstack([gaussian_params,
30-
avg_func(get_band_peak_group(fg.get_all_data('gaussian_params'), band, len(fg)), 0)])
44+
peak_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'peak_params'), 0) \
45+
for label, band in bands])
46+
gaussian_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'gaussian_params'), 0) \
47+
for label, band in bands])
3148

3249
r2 = avg_func(fg.get_all_data('r_squared'))
3350
error = avg_func(fg.get_all_data('error'))

fooof/tests/test_funcs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
###################################################################################################
1515
###################################################################################################
1616

17+
def test_average_fg(tfg, tbands):
18+
19+
nfm = average_fg(tfg, tbands)
20+
assert nfm
21+
1722
def test_combine_fooofs(tfm, tfg):
1823

1924
tfm2 = tfm.copy(); tfm3 = tfm.copy()

0 commit comments

Comments
 (0)