Skip to content

Commit 72ea4c8

Browse files
committed
udpate interp to take multiple regions
1 parent f9d7c62 commit 72ea4c8

File tree

2 files changed

+41
-19
lines changed

2 files changed

+41
-19
lines changed

fooof/tests/utils/test_data.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def test_trim_spectrum():
2121

2222
def test_interpolate_spectrum():
2323

24+
# Test with single buffer exclusion zone
2425
freqs, powers = gen_power_spectrum(\
2526
[1, 75], [1, 1], [[10, 0.5, 1.0], [60, 2, 0.1]])
2627

@@ -29,3 +30,12 @@ def test_interpolate_spectrum():
2930
assert np.array_equal(freqs, freqs_out)
3031
assert np.all(powers)
3132
assert powers.shape == powers_out.shape
33+
34+
# Test with multiple buffer exclusion zones
35+
freqs, powers = gen_power_spectrum(\
36+
[1, 150], [1, 100, 1], [[10, 0.5, 1.0], [60, 1, 0.1], [120, 0.5, 0.1]])
37+
38+
freqs_out, powers_out = interpolate_spectrum(freqs, powers, [[58, 62], [118, 122]])
39+
assert np.array_equal(freqs, freqs_out)
40+
assert np.all(powers)
41+
assert powers.shape == powers_out.shape

fooof/utils/data.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Utilities for working with data and models."""
22

3+
from itertools import repeat
4+
35
import numpy as np
46

57
###################################################################################################
@@ -60,9 +62,10 @@ def interpolate_spectrum(freqs, powers, interp_range, buffer=3):
6062
Frequency values for the power spectrum.
6163
powers : 1d array
6264
Power values for the power spectrum.
63-
interp_range : list of float
65+
interp_range : list of float or list of list of float
6466
Frequency range to interpolate, as [lowest_freq, highest_freq].
65-
buffer : int
67+
If a list of lists, applies each as it's own interpolation range.
68+
buffer : int or list of int
6669
The number of samples to use on either side of the interpolation
6770
range, that are then averaged and used to calculate the interpolation.
6871
@@ -101,26 +104,35 @@ def interpolate_spectrum(freqs, powers, interp_range, buffer=3):
101104
>>> freqs, powers = interpolate_spectrum(freqs, powers, [58, 62])
102105
"""
103106

104-
# Take a copy of the array, to not change original array
105-
powers = np.copy(powers)
107+
# If given a list of interpolation zones, recurse to apply each one
108+
if isinstance(interp_range[0], list):
109+
buffer = repeat(buffer) if isinstance(buffer, int) else buffer
110+
for interp_zone, cur_buffer in zip(interp_range, buffer):
111+
freqs, powers = interpolate_spectrum(freqs, powers, interp_zone, cur_buffer)
112+
113+
# Assuming list of two floats, interpolate a single frequency range
114+
else:
115+
116+
# Take a copy of the array, to not change original array
117+
powers = np.copy(powers)
106118

107-
# Get the set of frequency values that need to be interpolated
108-
interp_mask = np.logical_and(freqs >= interp_range[0], freqs <= interp_range[1])
109-
interp_freqs = freqs[interp_mask]
119+
# Get the set of frequency values that need to be interpolated
120+
interp_mask = np.logical_and(freqs >= interp_range[0], freqs <= interp_range[1])
121+
interp_freqs = freqs[interp_mask]
110122

111-
# Get the indices of the interpolation range
112-
ii1, ii2 = np.flatnonzero(interp_mask)[[0, -1]]
123+
# Get the indices of the interpolation range
124+
ii1, ii2 = np.flatnonzero(interp_mask)[[0, -1]]
113125

114-
# Extract & log the requested range of data to use around interpolated range
115-
xs1 = np.log10(freqs[ii1-buffer:ii1])
116-
xs2 = np.log10(freqs[ii2:ii2+buffer])
117-
ys1 = np.log10(powers[ii1-buffer:ii1])
118-
ys2 = np.log10(powers[ii2:ii2+buffer])
126+
# Extract & log the requested range of data to use around interpolated range
127+
xs1 = np.log10(freqs[ii1-buffer:ii1])
128+
xs2 = np.log10(freqs[ii2:ii2+buffer])
129+
ys1 = np.log10(powers[ii1-buffer:ii1])
130+
ys2 = np.log10(powers[ii2:ii2+buffer])
119131

120-
# Linearly interpolate, in log-log space, between averages of the extracted points
121-
vals = np.interp(np.log10(interp_freqs),
122-
[np.median(xs1), np.median(xs2)],
123-
[np.median(ys1), np.median(ys2)])
124-
powers[interp_mask] = np.power(10, vals)
132+
# Linearly interpolate, in log-log space, between averages of the extracted points
133+
vals = np.interp(np.log10(interp_freqs),
134+
[np.median(xs1), np.median(xs2)],
135+
[np.median(ys1), np.median(ys2)])
136+
powers[interp_mask] = np.power(10, vals)
125137

126138
return freqs, powers

0 commit comments

Comments
 (0)