Skip to content

Commit b694831

Browse files
committed
Add tests for bands objects, and associated fixes
1 parent 89ce84a commit b694831

File tree

4 files changed

+71
-14
lines changed

4 files changed

+71
-14
lines changed

fooof/bands.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,10 @@
22

33
from collections import OrderedDict
44

5-
import numpy as np
6-
75
###################################################################################################
86
###################################################################################################
97

10-
class Bands(object):
8+
class Bands():
119
"""Class to hold definitions of oscillation bands.
1210
1311
Attributes
@@ -16,19 +14,19 @@ class Bands(object):
1614
Dictionary of band definitions.
1715
"""
1816

19-
def __init__(self, input_bands=None):
17+
def __init__(self, input_bands={}):
2018
"""Initialize the Bands object.
2119
2220
Parameters
2321
----------
24-
input_bands : dict, optional (default = None)
22+
input_bands : dict, optional
2523
A dictionary of oscillation bands to use.
2624
"""
2725

2826
self.bands = OrderedDict()
2927

3028
for label, band_def in input_bands.items():
31-
self.add_band(label)
29+
self.add_band(label, band_def)
3230

3331

3432
def __getitem__(self, name):
@@ -46,7 +44,7 @@ def __getattr__(self, name):
4644

4745
def __repr__(self):
4846

49-
return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(*val) \
47+
return '\n'.join(['{:8} : {:2} - {:2} Hz'.format(key, *val) \
5048
for key, val in self.bands.items()])
5149

5250

@@ -57,12 +55,14 @@ def __len__(self):
5755

5856
@property
5957
def labels(self):
58+
"""Get the labels for all bands defined in the object."""
6059

6160
return list(self.bands.keys())
6261

6362

6463
@property
6564
def n_bands(self):
65+
"""Get the number of bands defined in the object."""
6666

6767
return len(self.bands)
6868

@@ -78,8 +78,7 @@ def add_band(self, label, band_definition):
7878
The lower and upper frequency limit of the band, in Hz.
7979
"""
8080

81-
# Check that band definition is properly formatted & add if so
82-
_check_band(label, band_definition)
81+
self._check_band(label, band_definition)
8382
self.bands[label] = band_definition
8483

8584

@@ -92,24 +91,24 @@ def remove_band(self, label):
9291
Band label to remove from band definitions.
9392
"""
9493

95-
self.bands.pop(rm_band)
94+
self.bands.pop(label)
9695

9796

9897
@staticmethod
9998
def _check_band(label, band_definition):
100-
"""Check that a proposed band definition is properly formatted.
99+
"""Check that a proposed band definition is valid.
101100
102101
Parameters
103102
----------
104103
label : str
105-
The name of the new oscillation band.
104+
The name of the new band.
106105
band_definition : tuple of (float, float)
107106
The lower and upper frequency limit of the band, in Hz.
108107
109108
Raises
110109
------
111110
InconsistentDataError
112-
If oscillation band definition is not properly formatted.
111+
If band definition is not properly formatted.
113112
"""
114113

115114
# Check that band name is a string

fooof/tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from fooof.core.modutils import safe_import
11-
from fooof.tests.utils import get_tfm, get_tfg
11+
from fooof.tests.utils import get_tfm, get_tfg, get_tbands
1212

1313
plt = safe_import('.pyplot', 'matplotlib')
1414

@@ -45,6 +45,10 @@ def tfm():
4545
def tfg():
4646
yield get_tfg()
4747

48+
@pytest.fixture(scope='session')
49+
def tbands():
50+
yield get_tbands()
51+
4852
@pytest.fixture(scope='session')
4953
def skip_if_no_mpl():
5054
if not safe_import('matplotlib'):

fooof/tests/test_bands.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Test functions for FOOOF bands."""
2+
3+
from py.test import raises
4+
5+
from fooof.bands import *
6+
7+
###################################################################################################
8+
###################################################################################################
9+
10+
def test_bands():
11+
12+
bands = Bands()
13+
assert isinstance(bands, Bands)
14+
15+
def test_bands_add_band():
16+
17+
bands = Bands()
18+
bands.add_band('test', (5, 10))
19+
assert bands.bands == {'test' : (5, 10)}
20+
21+
def test_bands_remove_band():
22+
23+
bands = Bands()
24+
bands.add_band('test', (5, 10))
25+
bands.remove_band('test')
26+
assert bands.bands == {}
27+
28+
def test_bands_errors():
29+
30+
bands = Bands()
31+
with raises(InconsistentDataError):
32+
bands.add_band(1, (1, 1))
33+
with raises(InconsistentDataError):
34+
bands.add_band('test', (1, 1, 1))
35+
with raises(InconsistentDataError):
36+
bands.add_band('test', (2, 1))
37+
38+
def test_bands_dunders(tbands):
39+
40+
assert tbands['theta']
41+
assert tbands.alpha
42+
assert repr(tbands)
43+
assert len(tbands) == 3
44+
45+
def test_bands_properties(tbands):
46+
47+
assert tbands.labels == ['theta', 'alpha', 'beta']
48+
assert tbands.n_bands == 3

fooof/tests/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from functools import wraps
44

55
from fooof import FOOOF, FOOOFGroup
6+
from fooof.bands import Bands
67
from fooof.synth import gen_power_spectrum, gen_group_power_spectra, param_sampler
78
from fooof.core.modutils import safe_import
89

@@ -36,6 +37,11 @@ def get_tfg():
3637

3738
return tfg
3839

40+
def get_tbands():
41+
"""Get a bands object, for testing."""
42+
43+
return Bands({'theta' : (4, 8), 'alpha' : (8, 12), 'beta' : (13, 30)})
44+
3945
def default_group_params():
4046
"""Create default parameters for generating a test group of power spectra."""
4147

0 commit comments

Comments
 (0)