Skip to content

Commit 7c75d64

Browse files
authored
Merge pull request #201 from fooof-tools/optimize
[ENH] Optimization
2 parents f81ee12 + c2c8643 commit 7c75d64

File tree

4 files changed

+31
-20
lines changed

4 files changed

+31
-20
lines changed

fooof/core/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ def group_three(vec):
1313
1414
Parameters
1515
----------
16-
vec : 1d array
17-
Array of items to group by 3. Length of array must be divisible by three.
16+
vec : list or 1d array
17+
List or array of items to group by 3. Length of array must be divisible by three.
1818
1919
Returns
2020
-------
21-
list of list
22-
List of lists, each with three items.
21+
array or list of list
22+
Array or list of lists, each with three items. Output type will match input type.
2323
2424
Raises
2525
------
@@ -30,7 +30,11 @@ def group_three(vec):
3030
if len(vec) % 3 != 0:
3131
raise ValueError("Wrong size array to group by three.")
3232

33-
return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
33+
# Reshape, if an array, as it's faster, otherwise asssume lise
34+
if isinstance(vec, np.ndarray):
35+
return np.reshape(vec, (-1, 3))
36+
else:
37+
return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
3438

3539

3640
def nearest_ind(array, value):

fooof/objs/fit.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,18 +1005,16 @@ def _create_peak_params(self, gaus_params):
10051005
with `freqs`, `fooofed_spectrum_` and `_ap_fit` all required to be available.
10061006
"""
10071007

1008-
peak_params = np.empty([0, 3])
1008+
peak_params = np.empty((len(gaus_params), 3))
10091009

10101010
for ii, peak in enumerate(gaus_params):
10111011

10121012
# Gets the index of the power_spectrum at the frequency closest to the CF of the peak
1013-
ind = min(range(len(self.freqs)), key=lambda ii: abs(self.freqs[ii] - peak[0]))
1013+
ind = np.argmin(np.abs(self.freqs - peak[0]))
10141014

10151015
# Collect peak parameter data
1016-
peak_params = np.vstack((peak_params,
1017-
[peak[0],
1018-
self.fooofed_spectrum_[ind] - self._ap_fit[ind],
1019-
peak[2] * 2]))
1016+
peak_params[ii] = [peak[0], self.fooofed_spectrum_[ind] - self._ap_fit[ind],
1017+
peak[2] * 2]
10201018

10211019
return peak_params
10221020

@@ -1035,8 +1033,8 @@ def _drop_peak_cf(self, guess):
10351033
Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].
10361034
"""
10371035

1038-
cf_params = [item[0] for item in guess]
1039-
bw_params = [item[2] * self._bw_std_edge for item in guess]
1036+
cf_params = guess[:, 0]
1037+
bw_params = guess[:, 2] * self._bw_std_edge
10401038

10411039
# Check if peaks within drop threshold from the edge of the frequency range
10421040
keep_peak = \

fooof/objs/utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,14 @@ def fit_fooof_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
219219
>>> fgs = fit_fooof_3d(fg, freqs, power_spectra, freq_range=[3, 30]) # doctest:+SKIP
220220
"""
221221

222-
fgs = []
223-
for cond_spectra in power_spectra:
224-
fg.fit(freqs, cond_spectra, freq_range, n_jobs)
225-
fgs.append(fg.copy())
222+
# Reshape 3d data to 2d and fit, in order to fit with a single group model object
223+
shape = np.shape(power_spectra)
224+
powers_2d = np.reshape(power_spectra, (shape[0] * shape[1], shape[2]))
225+
226+
fg.fit(freqs, powers_2d, freq_range, n_jobs)
227+
228+
# Reorganize 2d results into a list of model group objects, to reflect original shape
229+
fgs = [fg.get_group(range(dim_a * shape[1], (dim_a + 1) * shape[1])) \
230+
for dim_a in range(shape[0])]
226231

227232
return fgs

fooof/tests/objs/test_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,17 @@ def test_combine_errors(tfm, tfg):
120120

121121
def test_fit_fooof_3d(tfg):
122122

123-
n_spectra = 2
123+
n_groups = 2
124+
n_spectra = 3
124125
xs, ys = gen_group_power_spectra(n_spectra, *default_group_params())
125-
ys = np.stack([ys, ys], axis=0)
126+
ys = np.stack([ys] * n_groups, axis=0)
127+
spectra_shape = np.shape(ys)
126128

127129
tfg = FOOOFGroup()
128130
fgs = fit_fooof_3d(tfg, xs, ys)
129131

130-
assert len(fgs) == 2
132+
assert len(fgs) == n_groups == spectra_shape[0]
131133
for fg in fgs:
132134
assert fg
135+
assert len(fg) == n_spectra
136+
assert fg.power_spectra.shape == spectra_shape[1:]

0 commit comments

Comments
 (0)