@@ -20,19 +20,34 @@ class Bands():
2020 >>> bands = Bands({'theta' : [4, 8], 'alpha' : [8, 12], 'beta' : [15, 30]})
2121 """
2222
23- def __init__ (self , input_bands = {} ):
23+ def __init__ (self , input_bands = None , n_bands = None ):
2424 """Initialize the Bands object.
2525
2626 Parameters
2727 ----------
2828 input_bands : dict, optional
2929 A dictionary of oscillation bands.
30+ n_bands : int, optional
31+ The number of bands to extract from the spectra.
32+ Can only be specified if not providing `input_bands`.
33+
34+ Attributes
35+ ----------
36+ bands : OrderedDict
37+ Band definitions.
3038 """
3139
3240 self .bands = OrderedDict ()
3341
34- for label , band_def in input_bands .items ():
35- self .add_band (label , band_def )
42+ if input_bands :
43+ for label , band_def in input_bands .items ():
44+ self .add_band (label , band_def )
45+
46+ self ._n_bands = None
47+ if n_bands :
48+ if input_bands :
49+ raise ValueError ('Cannot provive both `input_bands` and `n_bands`.' )
50+ self ._n_bands = n_bands
3651
3752
3853 def __getitem__ (self , label ):
@@ -89,7 +104,10 @@ def definitions(self):
89104 def n_bands (self ):
90105 """The number of bands defined in the object."""
91106
92- return len (self .bands )
107+ if self ._n_bands is not None :
108+ return self ._n_bands
109+ else :
110+ return len (self .bands )
93111
94112
95113 def add_band (self , label , band_definition ):
@@ -103,6 +121,7 @@ def add_band(self, label, band_definition):
103121 The lower and upper frequency limit of the band, in Hz.
104122 """
105123
124+ self ._n_bands = None
106125 self ._check_band (label , band_definition )
107126 self .bands [label ] = tuple (band_definition )
108127
@@ -147,3 +166,28 @@ def _check_band(label, band_definition):
147166 # Safety check that limits are in correct order
148167 if not band_definition [0 ] < band_definition [1 ]:
149168 raise ValueError ("Band limit definitions are invalid." )
169+
170+
171+ def check_bands (bands ):
172+ """Check bands definition.
173+
174+ Parameters
175+ ----------
176+ bands : Bands or dict or int, optional
177+ How to organize peaks into bands.
178+
179+ Returns
180+ -------
181+ bands : Bands
182+ Bands definition.
183+ """
184+
185+ if not isinstance (bands , Bands ):
186+ if isinstance (bands , (dict , OrderedDict )):
187+ bands = Bands (bands )
188+ elif isinstance (bands , int ):
189+ bands = Bands (n_bands = bands )
190+ else :
191+ raise ValueError ('Bands definition not understood.' )
192+
193+ return bands
0 commit comments