Skip to content

Commit 8497ec1

Browse files
committed
update model objs for data / results & attribute defs
1 parent d790b41 commit 8497ec1

File tree

4 files changed

+36
-100
lines changed

4 files changed

+36
-100
lines changed

specparam/models/event.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
import numpy as np
44

55
from specparam.models import SpectralModel, SpectralTimeModel
6-
from specparam.objs.results import BaseResults3D
7-
from specparam.objs.data import BaseData3D
6+
from specparam.objs.results import Results3D
7+
from specparam.objs.data import Data3D
88
from specparam.plts.event import plot_event_model
99
from specparam.data.conversions import event_group_to_dataframe, dict_to_df
1010
from specparam.data.utils import flatten_results_dict
@@ -22,6 +22,7 @@
2222
###################################################################################################
2323

2424
@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'),
25+
docs_get_section(SpectralModel.__doc__, 'Attributes'),
2526
docs_get_section(SpectralModel.__doc__, 'Notes')])
2627
class SpectralTimeEventModel(SpectralTimeModel):
2728
"""Model a set of event as a combination of aperiodic and periodic components.
@@ -37,20 +38,7 @@ class SpectralTimeEventModel(SpectralTimeModel):
3738
3839
Attributes
3940
----------
40-
freqs : 1d array
41-
Frequency values for the power spectra.
42-
spectrograms : 3d array
43-
Power values for the spectrograms, organized as [n_events, n_freqs, n_time_windows].
44-
Power values are stored internally in log10 scale.
45-
freq_range : list of [float, float]
46-
Frequency range of the power spectra, as [lowest_freq, highest_freq].
47-
freq_res : float
48-
Frequency resolution of the power spectra.
49-
event_group_results : list of list of FitResults
50-
Full model results collected across all events and models.
51-
event_time_results : dict
52-
Results of the model fit across each time window, collected across events.
53-
Each value in the dictionary stores a model fit parameter, as [n_events, n_time_windows].
41+
% copied in from SpectralModel object
5442
5543
Notes
5644
-----
@@ -69,9 +57,9 @@ def __init__(self, *args, **kwargs):
6957
verbose=kwargs.pop('verbose', True),
7058
**kwargs)
7159

72-
self.data = BaseData3D()
60+
self.data = Data3D()
7361

74-
self.results = BaseResults3D(modes=self.modes,
62+
self.results = Results3D(modes=self.modes,
7563
metrics=kwargs.pop('metrics', None),
7664
bands=kwargs.pop('bands', None))
7765

specparam/models/group.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import numpy as np
99

1010
from specparam.models import SpectralModel
11-
from specparam.objs.data import BaseData2D
12-
from specparam.objs.results import BaseResults2D
11+
from specparam.objs.data import Data2D
12+
from specparam.objs.results import Results2D
1313
from specparam.objs.utils import run_parallel_group, pbar
1414
from specparam.plts.group import plot_group_model
1515
from specparam.io.models import save_group
@@ -25,6 +25,7 @@
2525
###################################################################################################
2626

2727
@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'),
28+
docs_get_section(SpectralModel.__doc__, 'Attributes'),
2829
docs_get_section(SpectralModel.__doc__, 'Notes')])
2930
class SpectralGroupModel(SpectralModel):
3031

@@ -41,27 +42,7 @@ class SpectralGroupModel(SpectralModel):
4142
4243
Attributes
4344
----------
44-
freqs : 1d array
45-
Frequency values for the power spectra.
46-
power_spectra : 2d array
47-
Power values for the group of power spectra, as [n_power_spectra, n_freqs].
48-
Power values are stored internally in log10 scale.
49-
freq_range : list of [float, float]
50-
Frequency range of the power spectra, as [lowest_freq, highest_freq].
51-
freq_res : float
52-
Frequency resolution of the power spectra.
53-
group_results : list of FitResults
54-
Results of the model fit for each power spectrum.
55-
has_data : bool
56-
Whether data is loaded to the object.
57-
has_model : bool
58-
Whether model results are available in the object.
59-
n_peaks_ : int
60-
The number of peaks fit in the model.
61-
n_null_ : int
62-
The number of models that failed to fit and/or that are marked as null.
63-
null_inds_ : list of int
64-
The indices of any models that are null.
45+
% copied in from SpectralModel object
6546
6647
Notes
6748
-----
@@ -85,9 +66,9 @@ def __init__(self, *args, **kwargs):
8566
verbose=kwargs.pop('verbose', True),
8667
**kwargs)
8768

88-
self.data = BaseData2D()
69+
self.data = Data2D()
8970

90-
self.results = BaseResults2D(modes=self.modes,
71+
self.results = Results2D(modes=self.modes,
9172
metrics=kwargs.pop('metrics', None),
9273
bands=kwargs.pop('bands', None))
9374

@@ -267,7 +248,7 @@ def load(self, file_name, file_path=None):
267248
self._reset_data_results(clear_spectrum=True, clear_results=True)
268249

269250

270-
@copy_doc_func_to_method(BaseResults2D.get_params)
251+
@copy_doc_func_to_method(Results2D.get_params)
271252
def get_params(self, name, field=None):
272253

273254
return self.results.get_params(name, field)

specparam/models/model.py

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import numpy as np
99

1010
from specparam.models.base import BaseModel
11-
from specparam.objs.data import BaseData
12-
from specparam.objs.results import BaseResults
11+
from specparam.objs.data import Data
12+
from specparam.objs.results import Results
1313
from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS
1414
from specparam.reports.save import save_model_report
1515
from specparam.reports.strings import gen_model_results_str
@@ -47,49 +47,25 @@ class SpectralModel(BaseModel):
4747
4848
Attributes
4949
----------
50-
freqs : 1d array
51-
Frequency values for the power spectrum.
52-
power_spectrum : 1d array
53-
Power values, stored internally in log10 scale.
54-
freq_range : list of [float, float]
55-
Frequency range of the power spectrum, as [lowest_freq, highest_freq].
56-
freq_res : float
57-
Frequency resolution of the power spectrum.
58-
modeled_spectrum_ : 1d array
59-
The full model fit of the power spectrum, in log10 scale.
60-
aperiodic_params_ : 1d array
61-
Fitted parameter values that define the aperiodic fit. As [Offset, (Knee), Exponent].
62-
The knee parameter is only included if aperiodic component is fit with a knee.
63-
peak_params_ : 2d array
64-
Fitted parameter values for the peaks. Each row is a peak, as [CF, PW, BW].
65-
gaussian_params_ : 2d array
66-
Fitted parameter values that define the gaussian fit(s).
67-
Each row is a gaussian, as [mean, height, standard deviation].
68-
r_squared_ : float
69-
R-squared of the fit between the input power spectrum and the full model fit.
70-
error_ : float
71-
Error of the full model fit.
72-
n_peaks_ : int
73-
The number of peaks fit in the model.
74-
has_data : bool
75-
Whether data is loaded to the object.
76-
has_model : bool
77-
Whether model results are available in the object.
78-
_debug : bool
79-
Whether the object is set in debug mode.
80-
If in debug mode, an error is raised if model fitting is unsuccessful.
81-
This should be controlled by using the `set_debug` method.
50+
algorithm : Algorithm
51+
Algorithm object with model fitting settings and procedures.
52+
modes : Modes
53+
Modes object with fit mode definitions.
54+
data : Data
55+
Data object with spectral data and metadata.
56+
results : Results
57+
Results object with model fit results and metrics.
8258
8359
Notes
8460
-----
85-
- Commonly used abbreviations used in this module include:
86-
CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic
8761
- Input power spectra must be provided in linear scale.
8862
Internally they are stored in log10 scale, as this is what the model operates upon.
8963
- Input power spectra should be smooth, as overly noisy power spectra may lead to bad fits.
9064
For example, raw FFT inputs are not appropriate. Where possible and appropriate, use
9165
longer time segments for power spectrum calculation to get smoother power spectra,
9266
as this will give better model fits.
67+
- Commonly used abbreviations used in this module include:
68+
CF: center frequency, PW: power, BW: Bandwidth, AP: aperiodic
9369
- The gaussian params are those that define the gaussian of the fit, where as the peak
9470
params are a modified version, in which the CF of the peak is the mean of the gaussian,
9571
the PW of the peak is the height of the gaussian over and above the aperiodic component,
@@ -106,9 +82,9 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
10682
periodic_mode=periodic_mode,
10783
verbose=verbose)
10884

109-
self.data = BaseData()
85+
self.data = Data()
11086

111-
self.results = BaseResults(modes=self.modes, metrics=metrics, bands=bands)
87+
self.results = Results(modes=self.modes, metrics=metrics, bands=bands)
11288

11389
self.algorithm = SpectralFitAlgorithm(
11490
peak_width_limits=peak_width_limits, max_n_peaks=max_n_peaks,
@@ -117,8 +93,8 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
11793
debug=debug, **model_kwargs)
11894

11995

120-
@replace_docstring_sections([docs_get_section(BaseData.add_data.__doc__, 'Parameters'),
121-
docs_get_section(BaseData.add_data.__doc__, 'Notes')])
96+
@replace_docstring_sections([docs_get_section(Data.add_data.__doc__, 'Parameters'),
97+
docs_get_section(Data.add_data.__doc__, 'Notes')])
12298
def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True):
12399
"""Add data (frequencies, and power spectrum values) to the current object.
124100
@@ -311,7 +287,7 @@ def load(self, file_name, file_path=None, regenerate=True):
311287
self.results._regenerate_model(self.data.freqs)
312288

313289

314-
@copy_doc_func_to_method(BaseResults.get_params)
290+
@copy_doc_func_to_method(Results.get_params)
315291
def get_params(self, name, field=None):
316292

317293
return self.results.get_params(name, field)

specparam/models/time.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Time model object and associated code for fitting the model to spectrograms."""
22

33
from specparam.models import SpectralModel, SpectralGroupModel
4-
from specparam.objs.results import BaseResults2DT
5-
from specparam.objs.data import BaseData2DT
4+
from specparam.objs.results import Results2DT
5+
from specparam.objs.data import Data2DT
66
from specparam.data.conversions import group_to_dataframe, dict_to_df
77
from specparam.data.utils import get_results_by_ind
88
from specparam.io.models import save_time
@@ -17,6 +17,7 @@
1717
###################################################################################################
1818

1919
@replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'),
20+
docs_get_section(SpectralModel.__doc__, 'Attributes'),
2021
docs_get_section(SpectralModel.__doc__, 'Notes')])
2122
class SpectralTimeModel(SpectralGroupModel):
2223
"""Model a spectrogram as a combination of aperiodic and periodic components.
@@ -32,17 +33,7 @@ class SpectralTimeModel(SpectralGroupModel):
3233
3334
Attributes
3435
----------
35-
freqs : 1d array
36-
Frequency values for the spectrogram.
37-
spectrogram : 2d array
38-
Power values for the spectrogram, as [n_freqs, n_time_windows].
39-
Power values are stored internally in log10 scale.
40-
freq_range : list of [float, float]
41-
Frequency range of the spectrogram, as [lowest_freq, highest_freq].
42-
freq_res : float
43-
Frequency resolution of the spectrogram.
44-
time_results : dict
45-
Results of the model fit across each time window.
36+
% copied in from SpectralModel object
4637
4738
Notes
4839
-----
@@ -66,9 +57,9 @@ def __init__(self, *args, **kwargs):
6657
verbose=kwargs.pop('verbose', True),
6758
**kwargs)
6859

69-
self.data = BaseData2DT()
60+
self.data = Data2DT()
7061

71-
self.results = BaseResults2DT(modes=self.modes,
62+
self.results = Results2DT(modes=self.modes,
7263
metrics=kwargs.pop('metrics', None),
7364
bands=kwargs.pop('bands', None))
7465

0 commit comments

Comments
 (0)