Skip to content

Commit d650984

Browse files
committed
Merge remote-tracking branch 'origin/dev' into feature/tuple_units
2 parents 0c33d99 + aaa4a6b commit d650984

File tree

3 files changed

+14
-9
lines changed

3 files changed

+14
-9
lines changed

src/progpy/data_models/lstm_model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,7 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
538538
from tensorflow import keras
539539

540540
# Build model
541-
callbacks = [
542-
keras.callbacks.ModelCheckpoint("best_model.keras", save_best_only=True)
543-
]
541+
callbacks = [ ]
544542

545543
if params['early_stop']:
546544
callbacks.append(keras.callbacks.EarlyStopping(**params['early_stop.cfg']))
@@ -593,8 +591,6 @@ def from_data(cls, inputs, outputs, event_states=None, t_met=None, **kwargs):
593591
workers=params['workers'],
594592
use_multiprocessing=(params['workers'] > 1))
595593

596-
model = keras.models.load_model("best_model.keras")
597-
598594
# Split model into separate models
599595
n_state_layers = params['layers'] + 1 + (params['dropout'] > 0) + (params['normalize'])
600596
output_layer_input = keras.layers.Input(model.layers[n_state_layers-1].output.shape[1:])

src/progpy/prognostics_model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def simulate_to_threshold(self, future_loading_eqn: abc.Callable = None, first_o
787787
integration_method: str, optional
788788
Integration method, e.g. 'rk4' or 'euler' (default: 'euler')
789789
save_freq : float, optional
790-
Frequency at which output is saved (s), e.g., save_freq = 10 \n
790+
Frequency at which output is saved (s), e.g., save_freq = 10. A save_freq of 0 will save every step. \n
791791
save_pts : list[float], optional
792792
Additional ordered list of custom times where output is saved (s), e.g., save_pts= [50, 75] \n
793793
eval_pts : list[float], optional
@@ -883,8 +883,8 @@ def simulate_to_threshold(self, future_loading_eqn: abc.Callable = None, first_o
883883
raise ValueError("'dt' must be positive, was {}".format(config['dt']))
884884
if not isinstance(config['save_freq'], Number) and not isinstance(config['save_freq'], tuple):
885885
raise TypeError("'save_freq' must be a number, was a {}".format(type(config['save_freq'])))
886-
if (isinstance(config['save_freq'], Number) and config['save_freq'] <= 0) or \
887-
(isinstance(config['save_freq'], tuple) and config['save_freq'][1] <= 0):
886+
if (isinstance(config['save_freq'], Number) and config['save_freq'] < 0) or \
887+
(isinstance(config['save_freq'], tuple) and config['save_freq'][1] < 0):
888888
raise ValueError("'save_freq' must be positive, was {}".format(config['save_freq']))
889889
if not isinstance(config['save_pts'], abc.Iterable):
890890
raise TypeError("'save_pts' must be list or array, was a {}".format(type(config['save_pts'])))
@@ -1013,7 +1013,8 @@ def next_time(t, x):
10131013
def next_time(t, x=None):
10141014
next_save_pt = save_pts[save_pt_index] if save_pt_index < len(save_pts) else float('inf')
10151015
next_eval_pt = eval_pts[eval_pt_index] if eval_pt_index < len(eval_pts) else float('inf')
1016-
return min(dt, next_save-t, next_save_pt-t, next_eval_pt-t)
1016+
opts = (item for item in (dt, next_save-t, next_save_pt-t, next_eval_pt-t) if item > 0)
1017+
return min(*opts)
10171018
elif dt_mode != 'function':
10181019
raise ValueError(f"'dt' mode {dt_mode} not supported. Must be 'constant', 'auto', or a function")
10191020

tests/test_base_models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,14 @@ def load(t, x=None):
918918
result = m.simulate_to_threshold(load, dt=2, save_freq=0.75, save_pts=[1.5, 2.5])
919919
self.assertListEqual(result.times, [0, 2, 4])
920920

921+
# With save_freq==0
922+
result = m.simulate_to_threshold(load, dt=2, save_freq=0)
923+
self.assertListEqual(result.times, [0, 2, 4])
924+
925+
# With save_freq==0 and auto step size
926+
result = m.simulate_to_threshold(load, dt=('auto', 2), save_freq=0)
927+
self.assertListEqual(result.times, [0, 2, 4])
928+
921929
result = m.simulate_to_threshold(load, dt=2, save_pts=[2.5])
922930
self.assertListEqual(result.times, [0, 4])
923931

0 commit comments

Comments
 (0)