Skip to content

Commit 670e1cf

Browse files
authored
Merge pull request #76 from nasa/feature/MC_default_samples
Improve default MC sample size
2 parents b9e3ff5 + bf99af4 commit 670e1cf

File tree

2 files changed

+97
-17
lines changed

2 files changed

+97
-17
lines changed

src/progpy/predictors/monte_carlo.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,56 @@ class MonteCarlo(Predictor):
1919
2020
Configuration Parameters
2121
------------------------------
22-
t0 : float
23-
Initial time at which prediction begins, e.g., 0
24-
dt : float
25-
Simulation step size (s), e.g., 0.1
26-
events : list[str]
27-
Events to predict (subset of model.events) e.g., ['event1', 'event2']
28-
horizon : float
29-
Prediction horizon (s)
30-
n_samples : int
31-
Number of samples to use. If not specified, a default value is used. If state is type UnweightedSamples and n_samples is not provided, the provided unweighted samples will be used directly.
32-
save_freq : float
33-
Frequency at which results are saved (s)
34-
save_pts : list[float]
35-
Any additional savepoints (s) e.g., [10.1, 22.5]
22+
n_samples : int, optional
23+
Default number of samples to use. If not specified, a default value is used. If state is type UnweightedSamples and n_samples is not provided, the provided unweighted samples will be used directly.
24+
save_freq : float, optional
25+
Default frequency at which results are saved (s).
3626
"""
3727

28+
__DEFAULT_N_SAMPLES = 100 # Default number of samples to use, if none specified and not UncertainData
29+
3830
default_parameters = {
39-
'n_samples': 100 # Default number of samples to use, if none specified
31+
'n_samples': None
4032
}
4133

4234
def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) -> PredictionResults:
35+
"""
36+
Perform a single prediction
37+
38+
Parameters
39+
----------
40+
state : UncertainData
41+
Distribution representing current state of the system
42+
future_loading_eqn : function (t, x) -> z
43+
Function to generate an estimate of loading at future time t, and state x
44+
45+
Keyword Arguments
46+
------------------
47+
t0 : float, optional
48+
Initial time at which prediction begins, e.g., 0
49+
dt : float, optional
50+
Simulation step size (s), e.g., 0.1
51+
events : list[str], optional
52+
Events to predict (subset of model.events) e.g., ['event1', 'event2']
53+
horizon : float, optional
54+
Prediction horizon (s)
55+
n_samples : int, optional
56+
Number of samples to use. If not specified, a default value is used. If state is type UnweightedSamples and n_samples is not provided, the provided unweighted samples will be used directly.
57+
save_freq : float, optional
58+
Frequency at which results are saved (s)
59+
save_pts : list[float], optional
60+
Any additional savepoints (s) e.g., [10.1, 22.5]
61+
62+
Return
63+
----------
64+
result from prediction, including: NameTuple
65+
* times (List[float]): Times for each savepoint such that inputs.snapshot(i), states.snapshot(i), outputs.snapshot(i), and event_states.snapshot(i) are all at times[i]
66+
* inputs (Prediction): Inputs at each savepoint such that inputs.snapshot(i) is the input distribution (type UncertainData) at times[i]
67+
* states (Prediction): States at each savepoint such that states.snapshot(i) is the state distribution (type UncertainData) at times[i]
68+
* outputs (Prediction): Outputs at each savepoint such that outputs.snapshot(i) is the output distribution (type UncertainData) at times[i]
69+
* event_states (Prediction): Event states at each savepoint such that event_states.snapshot(i) is the event state distribution (type UncertainData) at times[i]
70+
* time_of_event (UncertainData): Distribution of predicted Time of Event (ToE) for each predicted event, represented by some subclass of UncertaintData (e.g., MultivariateNormalDist)
71+
"""
4372
if isinstance(state, dict) or isinstance(state, self.model.StateContainer):
4473
from progpy.uncertain_data import ScalarData
4574
state = ScalarData(state, _type = self.model.StateContainer)
@@ -53,11 +82,20 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
5382
params['print'] = False
5483
params['progress'] = False
5584

85+
if not isinstance(state, UnweightedSamples) and params['n_samples'] is None:
86+
# if not unweighted samples, some sample number is required, so set to default.
87+
params['n_samples'] = MonteCarlo.__DEFAULT_N_SAMPLES
88+
elif isinstance(state, UnweightedSamples) and params['n_samples'] is None:
89+
params['n_samples'] = len(state) # number of samples is from provided state
90+
5691
if len(params['events']) == 0 and 'horizon' not in params:
5792
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")
5893

59-
# Sample from state if n_samples specified or state is not UnweightedSamples
60-
if not isinstance(state, UnweightedSamples) or len(state) != params['n_samples']:
94+
# Sample from state if n_samples specified or state is not UnweightedSamples (Case 2)
95+
# Or if is Unweighted samples, but there are the wrong number of samples (Case 1)
96+
if (
97+
(isinstance(state, UnweightedSamples) and len(state) != params['n_samples']) # Case 1
98+
or not isinstance(state, UnweightedSamples)): # Case 2
6199
state = state.sample(params['n_samples'])
62100

63101
es_eqn = self.model.event_state

tests/test_predictors.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,48 @@ def test_utp_surrogate(self):
385385

386386
def test_mc_surrogate(self):
387387
self._test_surrogate_pred(MonteCarlo)
388+
389+
def test_mc_num_samples(self):
390+
"""
391+
This test confirms that monte carlos sampling logic works as expected
392+
"""
393+
m = ThrownObject()
394+
def future_load(t, x=None):
395+
return m.InputContainer({})
396+
397+
pred = MonteCarlo(m)
398+
399+
# First test- scalar input
400+
x_scalar = ScalarData({'x': 10, 'v': 0})
401+
# Should default to 100 samples
402+
result = pred.predict(x_scalar, future_load)
403+
self.assertEqual(len(result.time_of_event), 100)
404+
# Repeat with less samples
405+
result = pred.predict(x_scalar, future_load, n_samples=10)
406+
self.assertEqual(len(result.time_of_event), 10)
407+
408+
# Second test- Same, but with multivariate normal input
409+
# Behavior should be the same
410+
x_mvnormal = MultivariateNormalDist(['x', 'v'], [10, 0], [[0.1, 0], [0, 0.1]])
411+
# Should default to 100 samples
412+
result = pred.predict(x_mvnormal, future_load)
413+
self.assertEqual(len(result.time_of_event), 100)
414+
# Repeat with less samples
415+
result = pred.predict(x_mvnormal, future_load, n_samples=10)
416+
self.assertEqual(len(result.time_of_event), 10)
417+
418+
# Third test- UnweightedSamples input
419+
x_uwsamples = UnweightedSamples([{'x': 10, 'v': 0}, {'x': 9.9, 'v': 0.1}, {'x': 10.1, 'v': -0.1}])
420+
# Should default to same as x_uwsamples - HERE IS THE DIFFERENCE FROM OTHER TYPES
421+
result = pred.predict(x_uwsamples, future_load)
422+
self.assertEqual(len(result.time_of_event), 3)
423+
# Should be exact same data, in the same order
424+
for i in range(3):
425+
self.assertEqual(result.states[i][0]['x'], x_uwsamples[i]['x'])
426+
self.assertEqual(result.states[i][0]['v'], x_uwsamples[i]['v'])
427+
# Repeat with more samples
428+
result = pred.predict(x_uwsamples, future_load, n_samples=10)
429+
self.assertEqual(len(result.time_of_event), 10)
388430

389431
# This allows the module to be executed directly
390432
def main():

0 commit comments

Comments
 (0)