|
3 | 3 | import numpy as np |
4 | 4 |
|
5 | 5 | from fooof import FOOOF, FOOOFGroup |
6 | | -from fooof.synth.gen import gen_freqs |
| 6 | +from fooof.data import FOOOFResults |
7 | 7 | from fooof.utils import compare_info |
| 8 | +from fooof.synth.gen import gen_freqs |
| 9 | +from fooof.analysis import get_band_peaks_fg |
8 | 10 |
|
9 | 11 | ################################################################################################### |
10 | 12 | ################################################################################################### |
11 | 13 |
|
12 | | -def average_fg(fg, bands, avg='mean'): |
13 | | - """Average across a FOOOFGroup object.""" |
| 14 | +def average_fg(fg, bands, avg_method='mean'): |
| 15 | + """Average across a FOOOFGroup object. |
14 | 16 |
|
15 | | - if avg == 'mean': |
16 | | - avg_func = np.nanmean |
17 | | - elif avg == 'median': |
18 | | - avg_func = np.nanmedian |
| 17 | + Parameters |
| 18 | + ---------- |
| 19 | + fg : FOOOFGroup |
| 20 | + A FOOOFGroup object with data to average across. |
| 21 | + bands : Bands |
| 22 | + Bands object that defines the frequency bands to collapse peaks across. |
| 23 | + avg : {'mean', 'median'} |
| 24 | + Averaging function to use. |
19 | 25 |
|
20 | | - ap_params = avg_func(fg.get_all_data('aperiodic_params'), 0) |
| 26 | + Returns |
| 27 | + ------- |
| 28 | + fm : FOOOF |
| 29 | + FOOOF object containing the average results from the FOOOFGroup input. |
| 30 | + """ |
21 | 31 |
|
22 | | - peak_params, gaussian_params = np.empty([0, 3]), np.empty([0, 3]) |
| 32 | + if avg_method not in ['mean', 'median']: |
| 33 | + raise ValueError('Requested average method not understood.') |
| 34 | + if not len(fg): |
| 35 | + raise ValueError('Input FOOOFGroup has no fit results - can not proceed.') |
23 | 36 |
|
24 | | - for label, band in bands: |
| 37 | + if avg_method == 'mean': |
| 38 | + avg_func = np.nanmean |
| 39 | + elif avg_method == 'median': |
| 40 | + avg_func = np.nanmedian |
25 | 41 |
|
26 | | - peak_params = np.vstack([peak_params, |
27 | | - avg_func(get_band_peak_group(fg.get_all_data('peak_params'), band, len(fg)), 0)]) |
| 42 | + ap_params = avg_func(fg.get_all_data('aperiodic_params'), 0) |
28 | 43 |
|
29 | | - gaussian_params = np.vstack([gaussian_params, |
30 | | - avg_func(get_band_peak_group(fg.get_all_data('gaussian_params'), band, len(fg)), 0)]) |
| 44 | + peak_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'peak_params'), 0) \ |
| 45 | + for label, band in bands]) |
| 46 | + gaussian_params = np.array([avg_func(get_band_peaks_fg(fg, band, 'gaussian_params'), 0) \ |
| 47 | + for label, band in bands]) |
31 | 48 |
|
32 | 49 | r2 = avg_func(fg.get_all_data('r_squared')) |
33 | 50 | error = avg_func(fg.get_all_data('error')) |
|
0 commit comments