Skip to content

Commit 1485c8a

Browse files
authored
Merge pull request #66 from nasa/feature/tuple_units
2 parents aaa4a6b + d650984 commit 1485c8a

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/progpy/data_models/lstm_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright © 2021 United States Government as represented by the Administrator of the
22
# National Aeronautics and Space Administration. All Rights Reserved.
33

4+
from collections import abc
45
from itertools import chain
56
import matplotlib.pyplot as plt
67
from numbers import Number
@@ -476,8 +477,8 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
476477
raise ValueError(f"layers must be greater than 0, got {params['layers']}")
477478
if np.isscalar(params['units']):
478479
params['units'] = [params['units'] for _ in range(params['layers'])]
479-
if not isinstance(params['units'], (list, np.ndarray)):
480-
raise TypeError(f"units must be a list of integers, not {type(params['units'])}")
480+
if not isinstance(params['units'], (abc.Sequence, np.ndarray)):
481+
raise TypeError(f"units must be a Sequence (e.g., list or tuple) of integers, not {type(params['units'])}")
481482
if len(params['units']) != params['layers']:
482483
raise ValueError(f"units must be a list of integers of length {params['layers']}, got {params['units']}")
483484
for i in range(params['layers']):
@@ -487,7 +488,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
487488
raise TypeError(f"dropout must be an float greater than or equal to 0, not {type(params['dropout'])}")
488489
if params['dropout'] < 0:
489490
raise ValueError(f"dropout must be greater than or equal to 0, got {params['dropout']}")
490-
if not isinstance(params['activation'], (list, np.ndarray)):
491+
if not isinstance(params['activation'], (list, tuple, np.ndarray)):
491492
params['activation'] = [params['activation'] for _ in range(params['layers'])]
492493
if not np.isscalar(params['validation_split']):
493494
raise TypeError(f"validation_split must be an float between 0 and 1, not {type(params['validation_split'])}")

tests/test_data_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def future_loading(t, x=None):
184184
[future_loading for _ in range(5)],
185185
dt=[TIMESTEP, TIMESTEP/2, TIMESTEP/4, TIMESTEP*2, TIMESTEP*4],
186186
window=2,
187+
units=(16, ), # Units as tuple
187188
epochs=20)
188189

189190
# Should get keys from original model

0 commit comments

Comments
 (0)