Skip to content

Commit cd63d9c

Browse files
committed
add get data & model component methods to fit object
1 parent ff8829f commit cd63d9c

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

fooof/objs/fit.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from numpy.linalg import LinAlgError
6464
from scipy.optimize import curve_fit
6565

66+
from fooof.core.utils import unlog
6667
from fooof.core.items import OBJ_DESC
6768
from fooof.core.info import get_indices
6869
from fooof.core.io import save_fm, load_json
@@ -596,6 +597,95 @@ def get_meta_data(self):
596597
for key in OBJ_DESC['meta_data']})
597598

598599

600+
def get_data_component(self, component='full', space='log'):
601+
"""Get a data component.
602+
603+
Parameters
604+
----------
605+
component : {'full', 'aperiodic', 'peak'}
606+
Which data component to return.
607+
'full' - full power spectrum
608+
'aperiodic' - isolated aperiodic data component
609+
'peak' - isolated peak data component
610+
space : {'log', 'linear'}
611+
Which space to return the data component in.
612+
'log' - returns in log10 space.
613+
'linear' - returns in linear space.
614+
615+
Returns
616+
-------
617+
output : 1d array
618+
Specified data component, in specified spacing.
619+
620+
Notes
621+
-----
622+
The 'space' parameter doesn't just define the spacing of the data component
623+
values, but rather defines the space of the additive data definiton such that
624+
`power_spectrum = aperiodic_component + peak_component`.
625+
With space set as 'log', this combination holds in log space.
626+
With space set as 'linear', this combination holds in linear space.
627+
"""
628+
629+
assert space in ['linear', 'log'], "Input for 'space' invalid."
630+
631+
if component == 'full':
632+
output = self.power_spectrum if space == 'log' else unlog(self.power_spectrum)
633+
elif component == 'aperiodic':
634+
output = self._spectrum_peak_rm if space == 'log' else \
635+
unlog(self.power_spectrum) / unlog(self._peak_fit)
636+
elif component == 'peak':
637+
output = self._spectrum_flat if space == 'log' else \
638+
unlog(self.power_spectrum) - unlog(self._ap_fit)
639+
else:
640+
raise ValueError('Input for component invalid.')
641+
642+
return output
643+
644+
645+
def get_model_component(self, component='full', space='log'):
646+
"""Get a model component.
647+
648+
Parameters
649+
----------
650+
component : {'full', 'aperiodic', 'peak'}
651+
Which model component to return.
652+
'full' - full model
653+
'aperiodic' - isolated aperiodic model component
654+
'peak' - isolated peak model component
655+
space : {'log', 'linear'}
656+
Which space to return the model component in.
657+
'log' - returns in log10 space.
658+
'linear' - returns in linear space.
659+
660+
Returns
661+
-------
662+
output : 1d array
663+
Specified model component, in specified spacing.
664+
665+
Notes
666+
-----
667+
The 'space' parameter doesn't just define the spacing of the model component
668+
values, but rather defines the space of the additive model such that
669+
`model = aperiodic_component + peak_component`.
670+
With space set as 'log', this combination holds in log space.
671+
With space set as 'lienar', this combination holds in linear space.
672+
"""
673+
674+
assert space in ['linear', 'log'], "Input for 'space' invalid."
675+
676+
if component == 'full':
677+
output = self.fooofed_spectrum_ if space == 'log' else unlog(self.fooofed_spectrum_)
678+
elif component == 'aperiodic':
679+
output = self._ap_fit if space == 'log' else unlog(self._ap_fit)
680+
elif component == 'peak':
681+
output = self._peak_fit if space == 'log' else \
682+
unlog(self.fooofed_spectrum_) - unlog(self._ap_fit)
683+
else:
684+
raise ValueError('Input for component invalid.')
685+
686+
return output
687+
688+
599689
def get_params(self, name, col=None):
600690
"""Return model fit parameters for specified feature(s).
601691

fooof/tests/objs/test_fit.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,17 @@ def test_obj_gets(tfm):
311311
results = tfm.get_results()
312312
assert isinstance(results, FOOOFResults)
313313

314+
def test_get_components(tfm):
315+
316+
# Make sure test object has been fit
317+
tfm.fit()
318+
319+
# Test get data & model components
320+
for comp in ['full', 'aperiodic', 'peak']:
321+
for space in ['log', 'linear']:
322+
assert isinstance(tfm.get_data_component(comp, space), np.ndarray)
323+
assert isinstance(tfm.get_model_component(comp, space), np.ndarray)
324+
314325
def test_get_params(tfm):
315326
"""Test the get_params method."""
316327

0 commit comments

Comments
 (0)