Skip to content

Commit 6bcad75

Browse files
committed
tweak get_model
1 parent e872308 commit 6bcad75

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

specparam/models/event.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,14 @@ def get_model(self, event_ind=None, window_ind=None, regenerate=True):
272272
model = super().get_model()
273273

274274
# Add data for specified single power spectrum, if available
275-
if self.data.has_data:
275+
if event_ind is not None and window_ind is not None and self.data.has_data:
276276
model.data.power_spectrum = self.data.spectrograms[event_ind][:, window_ind]
277277

278278
# Add results for specified power spectrum, regenerating full fit if requested
279-
model.results.add_results(self.results.event_group_results[event_ind][window_ind])
280-
if regenerate:
281-
model.results._regenerate_model(self.data.freqs)
279+
if event_ind is not None and window_ind is not None:
280+
model.results.add_results(self.results.event_group_results[event_ind][window_ind])
281+
if regenerate:
282+
model.results._regenerate_model(self.data.freqs)
282283

283284
return model
284285

specparam/tests/models/test_event.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,26 @@ def test_event_load():
118118

119119
def test_event_get_model(tfe):
120120

121+
# Check getting null model
122+
tfm_null = tfe.get_model()
123+
assert tfm_null
124+
# Check that settings are copied over properly, but data and results are empty
125+
for setting in tfe.algorithm.settings.names:
126+
assert getattr(tfe.algorithm, setting) == getattr(tfm_null.algorithm, setting)
127+
assert not tfm_null.data.has_data
128+
assert not tfm_null.results.has_model
129+
121130
# Check without regenerating
122131
tfm0 = tfe.get_model(0, 0, False)
123132
assert tfm0
133+
assert tfm0.data.has_data
134+
assert tfm0.results.has_model
124135

125136
# Check with regenerating
126137
tfm1 = tfe.get_model(1, 1, True)
127138
assert tfm1
139+
assert tfm1.data.has_data
140+
assert tfm1.results.has_model
128141
assert np.all(tfm1.results.modeled_spectrum_)
129142

130143
def test_event_get_params(tfe):

specparam/tests/models/test_group.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -332,14 +332,13 @@ def test_get_model(tfg):
332332
"""Check return of an individual model fit from a group object."""
333333

334334
# Test with no ind (no data / results)
335-
tfm = tfg.get_model()
336-
assert tfm
335+
tfm_null = tfg.get_model()
336+
assert tfm_null
337337
# Check that settings are copied over properly, but data and results are empty
338338
for setting in tfg.algorithm.settings.names:
339-
assert getattr(tfg.algorithm, setting) == getattr(tfm.algorithm, setting)
340-
for result in tfg.results._fields:
341-
assert np.all(np.isnan(getattr(tfm.results, result)))
342-
assert not tfm.data.power_spectrum
339+
assert getattr(tfg.algorithm, setting) == getattr(tfm_null.algorithm, setting)
340+
assert not tfm_null.data.has_data
341+
assert not tfm_null.results.has_model
343342

344343
# Check without regenerating
345344
tfm0 = tfg.get_model(0, False)

0 commit comments

Comments
 (0)