Skip to content

Commit 39286c2

Browse files
authored
Merge pull request #139 from fooof-tools/bands
Add support for managing bands & averaging over FOOOFGroup objects
2 parents 1a237f4 + c4ed7df commit 39286c2

File tree

8 files changed

+326
-27
lines changed

8 files changed

+326
-27
lines changed

fooof/analysis.py

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,64 @@
1-
"""Basic analysis functions for FOOOF results."""
1+
"""Analysis functions for FOOOF results."""
22

33
import numpy as np
44

55
###################################################################################################
66
###################################################################################################
77

8-
def get_band_peak_group(peak_params, band_def, n_fits):
9-
"""Extracts peaks within a given band of interest, for a group of FOOOF model fits.
8+
def get_band_peak_fm(fm, band, ret_one=True, attribute='peak_params'):
9+
"""Extract peaks from a band of interest from a FOOOFGroup object.
10+
11+
Parameters
12+
----------
13+
fm : FOOOF
14+
FOOOF object to extract peak data from.
15+
band : tuple of (float, float)
16+
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
17+
ret_one : bool, optional, default: True
18+
Whether to return single peak (if True) or all peaks within the range found (if False).
19+
If True, returns the highest power peak within the search range.
20+
attribute : {'peak_params', 'gaussian_params'}
21+
Which attribute of peak data to extract data from.
22+
23+
Returns
24+
-------
25+
1d or 2d array
26+
Peak data. Each row is a peak, as [CF, Amp, BW]
27+
"""
28+
29+
return get_band_peak(getattr(fm, attribute + '_'), band, ret_one)
30+
31+
32+
def get_band_peaks_fg(fg, band, attribute='peak_params'):
33+
"""Extract peaks from a band of interest from a FOOOF object.
34+
35+
Parameters
36+
----------
37+
fg : FOOOFGroup
38+
FOOOFGroup object to extract peak data from.
39+
band : tuple of (float, float)
40+
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
41+
attribute : {'peak_params', 'gaussian_params'}
42+
Which attribute of peak data to extract data from.
43+
44+
Returns
45+
-------
46+
2d array
47+
Peak data. Each row is a peak, as [CF, Amp, BW].
48+
"""
49+
50+
return get_band_peaks_group(fg.get_all_data(attribute), band, len(fg))
51+
52+
53+
def get_band_peaks_group(peak_params, band, n_fits):
54+
"""Extracts peaks within a given band of interest.
1055
1156
Parameters
1257
----------
1358
peak_params : 2d array
1459
Peak parameters, for a group fit, from FOOOF, with shape of [n_peaks, 4].
15-
band_def : [float, float]
16-
Defines the band of interest, as [lower_frequency_bound, upper_frequency_bound].
60+
band : tuple of (float, float)
61+
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
1762
n_fits : int
1863
The number of model fits in the FOOOFGroup data.
1964
@@ -34,32 +79,32 @@ def get_band_peak_group(peak_params, band_def, n_fits):
3479
3580
>>> peaks = np.empty((0, 3))
3681
>>> for f_res in fg:
37-
>>> peaks = np.vstack((peaks, get_band_peak(f_res.peak_params, band_def, ret_one=False)))
82+
>>> peaks = np.vstack((peaks, get_band_peak(f_res.peak_params, band, ret_one=False)))
3883
3984
"""
4085

4186
band_peaks = np.zeros(shape=[n_fits, 3])
42-
for ind in range(n_fits):
4387

44-
# Extacts an array per FOOOF fit, and extracts band peaks from it
88+
# Extacts an array per FOOOF fit, and extracts band peaks from it
89+
for ind in range(n_fits):
4590
band_peaks[ind, :] = get_band_peak(peak_params[tuple([peak_params[:, -1] == ind])][:, 0:3],
46-
band_def=band_def, ret_one=True)
91+
band=band, ret_one=True)
4792

4893
return band_peaks
4994

5095

51-
def get_band_peak(peak_params, band_def, ret_one=True):
52-
"""Extracts peaks within a given band of interest, for a FOOOF model fit.
96+
def get_band_peak(peak_params, band, ret_one=True):
97+
"""Extracts peaks within a given band of interest.
5398
5499
Parameters
55100
----------
56101
peak_params : 2d array
57102
Peak parameters, from FOOOF, with shape of [n_peaks, 3].
58-
band_def : [float, float]
59-
Defines the band of interest, as [lower_frequency_bound, upper_frequency_bound].
103+
band : tuple of (float, float)
104+
Defines the band of interest, as (lower_frequency_bound, upper_frequency_bound).
60105
ret_one : bool, optional, default: True
61106
Whether to return single peak (if True) or all peaks within the range found (if False).
62-
If True, returns the highest amplitude peak within the search range.
107+
If True, returns the highest power peak within the search range.
63108
64109
Returns
65110
-------
@@ -72,7 +117,7 @@ def get_band_peak(peak_params, band_def, ret_one=True):
72117
return np.array([np.nan, np.nan, np.nan])
73118

74119
# Find indices of peaks in the specified range, and check the number found
75-
peak_inds = (peak_params[:, 0] >= band_def[0]) & (peak_params[:, 0] <= band_def[1])
120+
peak_inds = (peak_params[:, 0] >= band[0]) & (peak_params[:, 0] <= band[1])
76121
n_peaks = sum(peak_inds)
77122

78123
# If there are no peaks within the specified range
@@ -82,12 +127,12 @@ def get_band_peak(peak_params, band_def, ret_one=True):
82127

83128
band_peaks = peak_params[peak_inds, :]
84129

85-
# If results > 1 and ret_one, then we return the highest amplitude peak
130+
# If results > 1 and ret_one, then we return the highest power peak
86131
# Call a sub-function to select highest power peak in band
87132
if n_peaks > 1 and ret_one:
88133
band_peaks = get_highest_amp_peak(band_peaks)
89134

90-
# If results == 1, return peak - [cen, power, bw]
135+
# If results == 1, return single peak
91136
return np.squeeze(band_peaks)
92137

93138

@@ -96,13 +141,13 @@ def get_highest_amp_peak(band_peaks):
96141
97142
Parameters
98143
----------
99-
peak_params : 2d array
144+
band_peaks : 2d array
100145
Peak parameters, from FOOOF, with shape of [n_peaks, 3].
101146
102147
Returns
103148
-------
104-
band_peaks : array
105-
Peak data. Each row is a peak, as [CF, Amp, BW].
149+
1d array
150+
Seleced peak data. Row is a peak, as [CF, Amp, BW].
106151
"""
107152

108153
# Catch & return NaN if empty

fooof/bands.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""A class for managing band definitions."""
2+
3+
from collections import OrderedDict
4+
5+
###################################################################################################
6+
###################################################################################################
7+
8+
class Bands():
9+
"""Class to hold bands definitions.
10+
11+
Attributes
12+
----------
13+
bands : dict
14+
Dictionary of band definitions.
15+
Each entry should be as {'label' : (f_low, f_high)}.
16+
"""
17+
18+
def __init__(self, input_bands={}):
19+
"""Initialize the Bands object.
20+
21+
Parameters
22+
----------
23+
input_bands : dict, optional
24+
A dictionary of oscillation bands to use.
25+
"""
26+
27+
self.bands = OrderedDict()
28+
29+
for label, band_def in input_bands.items():
30+
self.add_band(label, band_def)
31+
32+
def __getitem__(self, label):
33+
34+
try:
35+
return self.bands[label]
36+
except KeyError:
37+
message = "The label '{}' was not found in the defined bands.".format(label)
38+
raise BandNotDefinedError(message) from None
39+
40+
def __getattr__(self, label):
41+
42+
return self.__getitem__(label)
43+
44+
def __repr__(self):
45+
46+
return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(key, *val) \
47+
for key, val in self.bands.items()])
48+
49+
def __len__(self):
50+
51+
return self.n_bands
52+
53+
def __iter__(self):
54+
55+
for label, band_definition in self.bands.items():
56+
yield (label, band_definition)
57+
58+
@property
59+
def labels(self):
60+
"""Get the labels for all bands defined in the object."""
61+
62+
return list(self.bands.keys())
63+
64+
@property
65+
def n_bands(self):
66+
"""Get the number of bands defined in the object."""
67+
68+
return len(self.bands)
69+
70+
71+
def add_band(self, label, band_definition):
72+
"""Add a new oscillation band definition.
73+
74+
Parameters
75+
----------
76+
label : str
77+
Band label to add.
78+
band_definition : tuple of (float, float)
79+
The lower and upper frequency limit of the band, in Hz.
80+
"""
81+
82+
self._check_band(label, band_definition)
83+
self.bands[label] = band_definition
84+
85+
86+
def remove_band(self, label):
87+
"""Remove a previously defined oscillation band.
88+
89+
Parameters
90+
----------
91+
label : str
92+
Band label to remove from band definitions.
93+
"""
94+
95+
self.bands.pop(label)
96+
97+
98+
@staticmethod
99+
def _check_band(label, band_definition):
100+
"""Check that a proposed band definition is valid.
101+
102+
Parameters
103+
----------
104+
label : str
105+
The name of the new band.
106+
band_definition : tuple of (float, float)
107+
The lower and upper frequency limit of the band, in Hz.
108+
109+
Raises
110+
------
111+
InconsistentDataError
112+
If band definition is not properly formatted.
113+
"""
114+
115+
# Check that band name is a string
116+
if not isinstance(label, str):
117+
raise InconsistentDataError('Band name definition is not a string.')
118+
119+
# Check that band limits has the right size
120+
if not len(band_definition) == 2:
121+
raise InconsistentDataError('Band limit definition is not the right size.')
122+
123+
# Safety check that limits are in correct order
124+
if not band_definition[0] < band_definition[1]:
125+
raise InconsistentDataError('Band limit definitions are invalid.')
126+
127+
128+
class BandNotDefinedError(Exception):
129+
pass
130+
131+
class InconsistentDataError(Exception):
132+
pass

fooof/funcs.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,64 @@
22

33
import numpy as np
44

5-
from fooof import FOOOFGroup
6-
from fooof.synth.gen import gen_freqs
5+
from fooof import FOOOF, FOOOFGroup
6+
from fooof.data import FOOOFResults
77
from fooof.utils import compare_info
8+
from fooof.synth.gen import gen_freqs
9+
from fooof.analysis import get_band_peaks_fg
810

911
###################################################################################################
1012
###################################################################################################
1113

14+
def average_fg(fg, bands, avg_method='mean'):
15+
"""Average across a FOOOFGroup object.
16+
17+
Parameters
18+
----------
19+
fg : FOOOFGroup
20+
A FOOOFGroup object with data to average across.
21+
bands : Bands
22+
Bands object that defines the frequency bands to collapse peaks across.
23+
avg : {'mean', 'median'}
24+
Averaging function to use.
25+
26+
Returns
27+
-------
28+
fm : FOOOF
29+
FOOOF object containing the average results from the FOOOFGroup input.
30+
"""
31+
32+
if avg_method not in ['mean', 'median']:
33+
raise ValueError('Requested average method not understood.')
34+
if not len(fg):
35+
raise ValueError('Input FOOOFGroup has no fit results - can not proceed.')
36+
37+
if avg_method == 'mean':
38+
avg_func = np.nanmean
39+
elif avg_method == 'median':
40+
avg_func = np.nanmedian
41+
42+
ap_params = avg_func(fg.get_all_data('aperiodic_params'), 0)
43+
44+
peak_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'peak_params'), 0) \
45+
for label, band in bands])
46+
gaussian_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'gaussian_params'), 0) \
47+
for label, band in bands])
48+
49+
r2 = avg_func(fg.get_all_data('r_squared'))
50+
error = avg_func(fg.get_all_data('error'))
51+
52+
results = FOOOFResults(ap_params, peak_params, r2, error, gaussian_params)
53+
54+
# Create the new FOOOF object, with settings, data info & results
55+
fm = FOOOF()
56+
fm.add_settings(fg.get_settings())
57+
fm.add_data_info(fg.get_data_info())
58+
fm.add_results(results)
59+
60+
return fm
61+
62+
1263
def combine_fooofs(fooofs):
1364
"""Combine a group of FOOOF and/or FOOOFGroup objects into a single FOOOFGroup object.
1465

fooof/tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from fooof.core.modutils import safe_import
11-
from fooof.tests.utils import get_tfm, get_tfg
11+
from fooof.tests.utils import get_tfm, get_tfg, get_tbands
1212

1313
plt = safe_import('.pyplot', 'matplotlib')
1414

@@ -45,6 +45,10 @@ def tfm():
4545
def tfg():
4646
yield get_tfg()
4747

48+
@pytest.fixture(scope='session')
49+
def tbands():
50+
yield get_tbands()
51+
4852
@pytest.fixture(scope='session')
4953
def skip_if_no_mpl():
5054
if not safe_import('matplotlib'):

0 commit comments

Comments
 (0)