Skip to content

Commit 5a6fc79

Browse files
committed
update Bands for supporting n bands def
1 parent c029510 commit 5a6fc79

File tree

2 files changed

+78
-4
lines changed

2 files changed

+78
-4
lines changed

specparam/bands/bands.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,34 @@ class Bands():
2020
>>> bands = Bands({'theta' : [4, 8], 'alpha' : [8, 12], 'beta' : [15, 30]})
2121
"""
2222

23-
def __init__(self, input_bands={}):
23+
def __init__(self, input_bands=None, n_bands=None):
2424
"""Initialize the Bands object.
2525
2626
Parameters
2727
----------
2828
input_bands : dict, optional
2929
A dictionary of oscillation bands.
30+
n_bands : int, optional
31+
The number of bands to extract from the spectra.
32+
Can only be specified if not providing `input_bands`.
33+
34+
Attributes
35+
----------
36+
bands : OrderedDict
37+
Band definitions.
3038
"""
3139

3240
self.bands = OrderedDict()
3341

34-
for label, band_def in input_bands.items():
35-
self.add_band(label, band_def)
42+
if input_bands:
43+
for label, band_def in input_bands.items():
44+
self.add_band(label, band_def)
45+
46+
self._n_bands = None
47+
if n_bands:
48+
if input_bands:
49+
raise ValueError('Cannot provive both `input_bands` and `n_bands`.')
50+
self._n_bands = n_bands
3651

3752

3853
def __getitem__(self, label):
@@ -89,7 +104,10 @@ def definitions(self):
89104
def n_bands(self):
90105
"""The number of bands defined in the object."""
91106

92-
return len(self.bands)
107+
if self._n_bands is not None:
108+
return self._n_bands
109+
else:
110+
return len(self.bands)
93111

94112

95113
def add_band(self, label, band_definition):
@@ -103,6 +121,7 @@ def add_band(self, label, band_definition):
103121
The lower and upper frequency limit of the band, in Hz.
104122
"""
105123

124+
self._n_bands = None
106125
self._check_band(label, band_definition)
107126
self.bands[label] = tuple(band_definition)
108127

@@ -147,3 +166,28 @@ def _check_band(label, band_definition):
147166
# Safety check that limits are in correct order
148167
if not band_definition[0] < band_definition[1]:
149168
raise ValueError("Band limit definitions are invalid.")
169+
170+
171+
def check_bands(bands):
172+
"""Check bands definition.
173+
174+
Parameters
175+
----------
176+
bands : Bands or dict or int, optional
177+
How to organize peaks into bands.
178+
179+
Returns
180+
-------
181+
bands : Bands
182+
Bands definition.
183+
"""
184+
185+
if not isinstance(bands, Bands):
186+
if isinstance(bands, (dict, OrderedDict)):
187+
bands = Bands(bands)
188+
elif isinstance(bands, int):
189+
bands = Bands(n_bands=bands)
190+
else:
191+
raise ValueError('Bands definition not understood.')
192+
193+
return bands

specparam/tests/bands/test_bands.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,33 @@ def test_bands_properties(tbands):
5252

5353
assert set(tbands.labels) == set(['theta', 'alpha', 'beta'])
5454
assert tbands.n_bands == 3
55+
56+
def test_bands_n_bands():
57+
58+
n_bands = 2
59+
bands = Bands(n_bands=n_bands)
60+
assert bands.bands == {}
61+
assert bands._n_bands == n_bands
62+
assert len(bands) == n_bands
63+
64+
# Check that adding a band replaces n_band definition
65+
bands.add_band('alpha', [7, 14])
66+
assert bands._n_bands == None
67+
assert len(bands) == 1
68+
69+
# test fails adding both bands and n_bands
70+
with raises(ValueError):
71+
bands = Bands({'alpha' : (7, 14)}, n_bands=2)
72+
73+
74+
def test_check_bands(tbands):
75+
76+
out1 = check_bands(tbands)
77+
assert isinstance(out1, Bands)
78+
assert out1 == tbands
79+
80+
out2 = check_bands({'alpha' : (7, 14)})
81+
assert isinstance(out2, Bands)
82+
83+
out3 = check_bands(2)
84+
assert isinstance(out3, Bands)

0 commit comments

Comments
 (0)