Skip to content

Commit 93d21ef

Browse files
committed
update data & models for Band with n_bands
1 parent 5a6fc79 commit 93d21ef

File tree

4 files changed

+60
-56
lines changed

4 files changed

+60
-56
lines changed

specparam/data/conversions.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from specparam import Bands
5+
from specparam.bands.bands import Bands, check_bands
66
from specparam.modutils.dependencies import safe_import, check_dependency
77
from specparam.data.periodic import get_band_peak_arr
88
from specparam.data.utils import flatten_results_dict
@@ -21,15 +21,18 @@ def model_to_dict(fit_results, modes, bands):
2121
Results of a model fit.
2222
modes : Modes
2323
Model modes definition.
24-
bands : Bands or int
25-
How to organize peaks, based on band definitions (Bands) or number of peaks (int).
24+
bands : Bands or dict or int
25+
How to organize peaks, based on band definitions.
26+
Can be Bands object or object that can be converted into a Bands object.
2627
2728
Returns
2829
-------
2930
dict
3031
Model results organized into a dictionary.
3132
"""
3233

34+
bands = check_bands(bands)
35+
3336
fr_dict = {}
3437

3538
# aperiodic parameters
@@ -38,18 +41,18 @@ def model_to_dict(fit_results, modes, bands):
3841

3942
# periodic parameters
4043
peaks = fit_results.peak_params
44+
if not bands.bands and bands.n_bands:
4145

42-
if isinstance(bands, int):
43-
44-
if len(peaks) < bands:
45-
nans = [np.array([np.nan] * 3) for ind in range(bands-len(peaks))]
46+
# If bands if defined in terms of number of peaks
47+
if len(peaks) < bands.n_bands:
48+
nans = [np.array([np.nan] * 3) for ind in range(bands.n_bands-len(peaks))]
4649
peaks = np.vstack((peaks, nans))
4750

48-
for ind, peak in enumerate(peaks[:bands, :]):
51+
for ind, peak in enumerate(peaks[:bands.n_bands, :]):
4952
for pe_label, pe_param in zip(modes.periodic.params.indices, peak):
5053
fr_dict[pe_label + '_' + str(ind)] = pe_param
5154

52-
elif isinstance(bands, Bands):
55+
elif bands.bands:
5356
for band, f_range in bands:
5457
for label, param in zip(modes.periodic.params.indices,
5558
get_band_peak_arr(peaks, f_range)):
@@ -72,16 +75,17 @@ def model_to_dataframe(fit_results, modes, bands):
7275
Results of a model fit.
7376
modes : Modes
7477
Model modes definition.
75-
bands : Bands or int
76-
How to organize peaks, based on band definitions (Bands) or number of peaks (int).
78+
bands : Bands or dict or int
79+
How to organize peaks, based on band definitions.
80+
Can be Bands object or object that can be converted into a Bands object.
7781
7882
Returns
7983
-------
8084
pd.Series
8185
Model results organized into a dataframe.
8286
"""
8387

84-
return pd.Series(model_to_dict(fit_results, modes, bands))
88+
return pd.Series(model_to_dict(fit_results, modes, check_bands(bands)))
8589

8690

8791
def group_to_dict(group_results, modes, bands):
@@ -93,15 +97,18 @@ def group_to_dict(group_results, modes, bands):
9397
List of FitResults objects, reflecting model results across a group of power spectra.
9498
modes : Modes
9599
Model modes definition.
96-
bands : Bands or int
97-
How to organize peaks, based on band definitions (Bands) or number of peaks (int).
100+
bands : Bands or dict or int
101+
How to organize peaks, based on band definitions.
102+
Can be Bands object or object that can be converted into a Bands object.
98103
99104
Returns
100105
-------
101106
dict
102107
Model results organized into a dictionary.
103108
"""
104109

110+
bands = check_bands(bands)
111+
105112
nres = len(group_results)
106113
fr_dict = {ke : np.zeros(nres) for ke in model_to_dict(group_results[0], modes, bands)}
107114
for ind, f_res in enumerate(group_results):
@@ -121,16 +128,17 @@ def group_to_dataframe(group_results, modes, bands):
121128
List of FitResults objects.
122129
modes : Modes
123130
Model modes definition.
124-
bands : Bands or int
125-
How to organize peaks, based on band definitions (Bands) or number of peaks (int).
131+
bands : Bands or dict or int
132+
How to organize peaks, based on band definitions.
133+
Can be Bands object or object that can be converted into a Bands object.
126134
127135
Returns
128136
-------
129137
pd.DataFrame
130138
Model results organized into a dataframe.
131139
"""
132140

133-
return pd.DataFrame(group_to_dict(group_results, modes, bands))
141+
return pd.DataFrame(group_to_dict(group_results, modes, check_bands(bands)))
134142

135143

136144
def event_group_to_dict(event_group_results, modes, bands):
@@ -142,8 +150,9 @@ def event_group_to_dict(event_group_results, modes, bands):
142150
Model fit results from across a set of events.
143151
modes : Modes
144152
Model modes definition.
145-
bands : Bands or int
146-
How to organize peaks, based on band definitions (Bands) or number of peaks (int).
153+
bands : Bands or dict or int
154+
How to organize peaks, based on band definitions.
155+
Can be Bands object or object that can be converted into a Bands object.
147156
148157
Returns
149158
-------
@@ -152,6 +161,7 @@ def event_group_to_dict(event_group_results, modes, bands):
152161
"""
153162

154163
event_time_results = {}
164+
bands = check_bands(bands)
155165

156166
for key in group_to_dict(event_group_results[0], modes, bands):
157167
event_time_results[key] = []
@@ -177,8 +187,9 @@ def event_group_to_dataframe(event_group_results, modes, bands):
177187
List of FitResults objects.
178188
modes : Modes
179189
Model modes definition.
180-
bands : Bands or int
181-
How to organize peaks, based on band definitions (Bands) or number of peaks (int).
190+
bands : Bands or dict or int
191+
How to organize peaks, based on band definitions.
192+
Can be Bands object or object that can be converted into a Bands object.
182193
183194
Returns
184195
-------
@@ -187,7 +198,7 @@ def event_group_to_dataframe(event_group_results, modes, bands):
187198
"""
188199

189200
return pd.DataFrame(flatten_results_dict(\
190-
event_group_to_dict(event_group_results, modes, bands)))
201+
event_group_to_dict(event_group_results, modes, check_bands(bands))))
191202

192203

193204
@check_dependency(pd, 'pandas')

specparam/models/event.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,9 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, bands=None,
109109
If 3d array, should have shape [n_events, n_freqs, n_time_windows].
110110
freq_range : list of [float, float], optional
111111
Frequency range to fit the model to. If not provided, fits the entire given range.
112-
bands : Bands or int, optional
112+
bands : Bands or dict or int, optional
113113
How to organize peaks into bands.
114-
If Bands, extracts peaks based on band definitions.
115-
If int, extracts the first n peaks.
114+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
116115
n_jobs : int, optional, default: 1
117116
Number of jobs to run in parallel.
118117
1 is no parallelization. -1 uses all available cores.
@@ -172,10 +171,9 @@ def report(self, freqs=None, spectrograms=None, freq_range=None,
172171
If a 3d array, should have shape [n_events, n_freqs, n_time_windows].
173172
freq_range : list of [float, float], optional
174173
Frequency range to fit the model to. If not provided, fits the entire given range.
175-
bands : Bands or int, optional
174+
bands : Bands or dict or int, optional
176175
How to organize peaks into bands.
177-
If Bands, extracts peaks based on band definitions.
178-
If int, extracts the first n peaks.
176+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
179177
n_jobs : int, optional, default: 1
180178
Number of jobs to run in parallel.
181179
1 is no parallelization. -1 uses all available cores.
@@ -385,11 +383,10 @@ def to_df(self, bands=None):
385383
386384
Parameters
387385
----------
388-
bands : Bands or int, optional
386+
bands : Bands or dict or int, optional
389387
How to organize peaks into bands.
390-
If Bands, extracts peaks based on band definitions.
391-
If int, extracts the first n peaks.
392-
If provided, re-extracts peak features; if not provided, converts from `time_results`.
388+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
389+
If provided, re-extracts peak features; if not, converts from `event_group_results`.
393390
394391
Returns
395392
-------
@@ -410,10 +407,9 @@ def convert_results(self, bands=None):
410407
411408
Parameters
412409
----------
413-
bands : Bands or int, optional
410+
bands : Bands or dict or int, optional
414411
How to organize peaks into bands.
415-
If Bands, extracts peaks based on band definitions.
416-
If int, extracts the first 'n' peaks.
412+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
417413
If not provided, uses band definition available in object.
418414
"""
419415

specparam/models/time.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,9 @@ def fit(self, freqs=None, spectrogram=None, freq_range=None, bands=None,
7878
Spectrogram of power spectrum values, in linear space.
7979
freq_range : list of [float, float], optional
8080
Frequency range to fit the model to. If not provided, fits the entire given range.
81-
bands : Bands or int, optional
81+
bands : Bands or dict or int, optional
8282
How to organize peaks into bands.
83-
If Bands, extracts peaks based on band definitions.
84-
If int, extracts the first n peaks.
83+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
8584
n_jobs : int, optional, default: 1
8685
Number of jobs to run in parallel.
8786
1 is no parallelization. -1 uses all available cores.
@@ -121,10 +120,9 @@ def report(self, freqs=None, spectrogram=None, freq_range=None,
121120
Spectrogram of power spectrum values, in linear space.
122121
freq_range : list of [float, float], optional
123122
Frequency range to fit the model to. If not provided, fits the entire given range.
124-
bands : Bands or int, optional
123+
bands : Bands or dict or int, optional
125124
How to organize peaks into bands.
126-
If Bands, extracts peaks based on band definitions.
127-
If int, extracts the first n peaks.
125+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
128126
n_jobs : int, optional, default: 1
129127
Number of jobs to run in parallel.
130128
1 is no parallelization. -1 uses all available cores.
@@ -251,11 +249,10 @@ def to_df(self, bands=None):
251249
252250
Parameters
253251
----------
254-
bands : Bands or int, optional
252+
bands : Bands or dict or int, optional
255253
How to organize peaks into bands.
256-
If Bands, extracts peaks based on band definitions.
257-
If int, extracts the first n peaks.
258-
If provided, re-extracts peak features; if not provided, converts from `time_results`.
254+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
255+
If provided, re-extracts peak features; if not, converts from `time_results`.
259256
260257
Returns
261258
-------
@@ -276,10 +273,9 @@ def convert_results(self, bands):
276273
277274
Parameters
278275
----------
279-
bands : Bands or int, optional
276+
bands : Bands or dict or int, optional
280277
How to organize peaks into bands.
281-
If Bands, extracts peaks based on band definitions.
282-
If int, extracts the first 'n' peaks.
278+
If Bands or dict, uses band definitions. If int, extracts the first 'n' peaks.
283279
If not provided, uses band definition available in object.
284280
"""
285281

specparam/tests/data/test_conversions.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66

7+
from specparam.bands import Bands
78
from specparam.modutils.dependencies import safe_import
89
pd = safe_import('pandas')
910

@@ -14,18 +15,18 @@
1415

1516
def test_model_to_dict(tresults, tmodes, tbands):
1617

17-
out = model_to_dict(tresults, tmodes, 1)
18+
out = model_to_dict(tresults, tmodes, Bands(n_bands=1))
1819
assert isinstance(out, dict)
1920
assert 'cf_0' in out
2021
assert out['cf_0'] == tresults.peak_params[0, 0]
2122
assert 'cf_1' not in out
2223

23-
out = model_to_dict(tresults, tmodes, 2)
24+
out = model_to_dict(tresults, tmodes, Bands(n_bands=2))
2425
assert 'cf_0' in out
2526
assert 'cf_1' in out
2627
assert out['cf_1'] == tresults.peak_params[1, 0]
2728

28-
out = model_to_dict(tresults, tmodes, 3)
29+
out = model_to_dict(tresults, tmodes, Bands(n_bands=3))
2930
assert 'cf_2' in out
3031
assert np.isnan(out['cf_2'])
3132

@@ -35,7 +36,7 @@ def test_model_to_dict(tresults, tmodes, tbands):
3536
def test_model_to_dataframe(tresults, tmodes, tbands, skip_if_no_pandas):
3637

3738
for nbands in [1, 2, 3]:
38-
out = model_to_dataframe(tresults, tmodes, nbands)
39+
out = model_to_dataframe(tresults, tmodes, Bands(n_bands=nbands))
3940
assert isinstance(out, pd.Series)
4041

4142
out = model_to_dataframe(tresults, tmodes, tbands)
@@ -46,7 +47,7 @@ def test_group_to_dict(tresults, tmodes, tbands):
4647
fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]
4748

4849
for nbands in [1, 2, 3]:
49-
out = group_to_dict(fit_results, tmodes, nbands)
50+
out = group_to_dict(fit_results, tmodes, Bands(n_bands=nbands))
5051
assert isinstance(out, dict)
5152

5253
out = group_to_dict(fit_results, tmodes, tbands)
@@ -57,7 +58,7 @@ def test_group_to_dataframe(tresults, tmodes, tbands, skip_if_no_pandas):
5758
fit_results = [deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]
5859

5960
for nbands in [1, 2, 3]:
60-
out = group_to_dataframe(fit_results, tmodes, nbands)
61+
out = group_to_dataframe(fit_results, tmodes, Bands(n_bands=nbands))
6162
assert isinstance(out, pd.DataFrame)
6263

6364
out = group_to_dataframe(fit_results, tmodes, tbands)
@@ -69,7 +70,7 @@ def test_event_group_to_dict(tresults, tmodes, tbands):
6970
[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]]
7071

7172
for nbands in [1, 2, 3]:
72-
out = event_group_to_dict(fit_results, tmodes, nbands)
73+
out = event_group_to_dict(fit_results, tmodes, Bands(n_bands=nbands))
7374
assert isinstance(out, dict)
7475

7576
out = event_group_to_dict(fit_results, tmodes, tbands)
@@ -81,7 +82,7 @@ def test_event_group_to_dataframe(tresults, tmodes, tbands, skip_if_no_pandas):
8182
[deepcopy(tresults), deepcopy(tresults), deepcopy(tresults)]]
8283

8384
for nbands in [1, 2, 3]:
84-
out = event_group_to_dataframe(fit_results, tmodes, nbands)
85+
out = event_group_to_dataframe(fit_results, tmodes, Bands(n_bands=nbands))
8586
assert isinstance(out, pd.DataFrame)
8687

8788
out = event_group_to_dataframe(fit_results, tmodes, tbands)

0 commit comments

Comments
 (0)