Skip to content

Commit 386d922

Browse files
committed
allow for null Modes in Results
1 parent e8bf0b8 commit 386d922

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

specparam/modes/modes.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ def __init__(self, aperiodic, periodic):
3232
def check_params(self):
3333
"""Check the description of the parameters for each mode."""
3434

35-
self.aperiodic.check_params()
36-
self.periodic.check_params()
35+
if self.aperiodic:
36+
self.aperiodic.check_params()
37+
if self.periodic:
38+
self.periodic.check_params()
3739

3840

3941
def get_modes(self):
@@ -45,19 +47,25 @@ def get_modes(self):
4547
Modes definition.
4648
"""
4749

48-
return ModelModes(aperiodic_mode=self.aperiodic.name, periodic_mode=self.periodic.name)
50+
return ModelModes(aperiodic_mode=self.aperiodic.name if self.aperiodic else None,
51+
periodic_mode=self.periodic.name if self.periodic else None)
4952

5053

5154
def check_mode_definition(mode, options):
5255
"""Check a mode specification.
5356
5457
Parameters
5558
----------
56-
mode : str or Mode
59+
mode : str or None or Mode
5760
Fit mode. If str, should be a label corresponding to an entry in `options`.
5861
options : dict
5962
Available modes.
6063
64+
Returns
65+
-------
66+
mode : Mode or None
67+
Mode object, if defined, or None if not defined.
68+
6169
Raises
6270
------
6371
ValueError
@@ -68,7 +76,9 @@ def check_mode_definition(mode, options):
6876
assert mode in list(options.keys()), 'Specific Mode not found.'
6977
mode = options[mode]
7078

71-
if not isinstance(mode, Mode):
79+
if mode is None:
80+
mode = None
81+
elif not isinstance(mode, Mode):
7282
raise ValueError('Mode input not understood.')
7383

7484
return mode

specparam/objs/results.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
from specparam.bands.bands import check_bands
9+
from specparam.modes.modes import Modes
910
from specparam.objs.metrics import Metrics
1011
from specparam.objs.params import ModelParameters
1112
from specparam.objs.components import ModelComponents
@@ -57,7 +58,7 @@ class Results():
5758
def __init__(self, modes=None, metrics=None, bands=None):
5859
"""Initialize Results object."""
5960

60-
self.modes = modes
61+
self.modes = modes if modes else Modes(None, None)
6162

6263
self.add_bands(bands)
6364
self.add_metrics(metrics)
@@ -148,7 +149,7 @@ def add_results(self, results):
148149

149150
# TODO: use check_array_dim for peak arrays? Or is / should this be done in `add_params`
150151

151-
for component in ['aperiodic', 'periodic']:
152+
for component in self.modes.components:
152153
for version in ['fit', 'converted']:
153154
attr_comp = 'peak' if component == 'periodic' else component
154155
getattr(self.params, component).add_params(\

specparam/tests/objs/test_results.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ def test_results():
1212
tres = Results()
1313
assert isinstance(tres, Results)
1414

15-
def test_results_results(tresults):
15+
def test_results_results(tresults, tmodes):
1616

1717
tres = Results()
1818

1919
tres.add_results(tresults)
2020
assert tres.has_model
21-
for component in ['aperiodic', 'periodic']:
21+
for component in tmodes.components:
2222
attr_comp = 'peak' if component == 'periodic' else component
2323
assert np.array_equal(getattr(tres.params, component).get_params('fit'),
2424
getattr(tresults, attr_comp + '_fit'))

0 commit comments

Comments
 (0)