Skip to content

Commit f7b01e3

Browse files
authored
Merge branch 'dev' into feature/particle_filter_fix
2 parents 9ba241a + ff234f0 commit f7b01e3

File tree

8 files changed

+125
-56
lines changed

8 files changed

+125
-56
lines changed

examples/basic_example.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
def run_example():
2121
# Step 1: Setup model & future loading
2222
m = ThrownObject(process_noise = 1)
23-
def future_loading(t, x = None):
24-
# No load for a thrown object
25-
return m.InputContainer({})
2623
initial_state = m.initialize()
2724

2825
# Step 2: Demonstrating state estimator
@@ -42,7 +39,7 @@ def future_loading(t, x = None):
4239
# Step 2c: Perform state estimation step, given some measurement, above what's expected
4340
example_measurements = m.OutputContainer({'x': 7.5})
4441
t = 0.1
45-
u = future_loading(t)
42+
u = m.InputContainer({})
4643
filt.estimate(t, u, example_measurements) # Update state, given (example) sensor data
4744

4845
# Step 2d: Print & Plot Resulting Posterior State
@@ -65,7 +62,7 @@ def future_loading(t, x = None):
6562
# Step 3b: Perform a prediction
6663
NUM_SAMPLES = 50
6764
STEP_SIZE = 0.01
68-
mc_results = mc.predict(filt.x, future_loading, n_samples = NUM_SAMPLES, dt=STEP_SIZE, save_freq=STEP_SIZE)
65+
mc_results = mc.predict(filt.x, n_samples = NUM_SAMPLES, dt=STEP_SIZE, save_freq=STEP_SIZE)
6966
print('Predicted time of event (ToE): ', mc_results.time_of_event.mean)
7067
# Here there are 2 events predicted, when the object starts falling, and when it impacts the ground.
7168

examples/horizon.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818

1919
def run_example():
2020
# Step 1: Setup model & future loading
21-
def future_loading(t, x = None):
22-
return {}
23-
m = ThrownObject(process_noise = 0.25, measurement_noise = 0.2)
21+
m = ThrownObject(process_noise=0.25, measurement_noise=0.2)
2422
initial_state = m.initialize()
2523

2624
# Step 2: Demonstrating state estimator
@@ -53,7 +51,7 @@ def future_loading(t, x = None):
5351
PREDICTION_HORIZON = 7.75
5452
samples = filt.x # Since we're using a particle filter, which is also sample-based, we can directly use the samples, without changes
5553
STEP_SIZE = 0.01
56-
mc_results = mc.predict(samples, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON)
54+
mc_results = mc.predict(samples, dt=STEP_SIZE, horizon=PREDICTION_HORIZON)
5755
print("\nPredicted Time of Event:")
5856
metrics = mc_results.time_of_event.metrics()
5957
pprint(metrics) # Note this takes some time

examples/predict_specific_event.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ def run_example():
1212
m = ThrownObject()
1313
initial_state = m.initialize()
1414
load = m.InputContainer({}) # Optimization - create once
15-
def future_loading(t, x = None):
16-
return load
1715

1816
## State Estimation - perform a single ukf state estimate step
1917
filt = state_estimators.UnscentedKalmanFilter(m, initial_state)
@@ -24,7 +22,7 @@ def future_loading(t, x = None):
2422
pred = predictors.UnscentedTransformPredictor(m)
2523

2624
# Predict with a step size of 0.1
27-
mc_results = pred.predict(filt.x, future_loading, dt=0.1, save_freq= 1, events=['impact'])
25+
mc_results = pred.predict(filt.x, dt=0.1, save_freq= 1, events=['impact'])
2826

2927
# Print Results
3028
for i, time in enumerate(mc_results.times):

examples/sensitivity.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,18 @@ def run_example():
1414
# Step 1: Create instance of model
1515
m = ThrownObject()
1616

17-
# Step 2: Setup for simulation
18-
def future_load(t, x=None):
19-
return m.InputContainer({})
20-
21-
# Step 3: Setup range on parameters considered
17+
# Step 2: Setup range on parameters considered
2218
thrower_height_range = np.arange(1.2, 2.1, 0.1)
2319

24-
# Step 4: Sim for each
20+
# Step 3: Sim for each
2521
event = 'impact'
2622
eods = np.empty(len(thrower_height_range))
2723
for (i, thrower_height) in zip(range(len(thrower_height_range)), thrower_height_range):
2824
m.parameters['thrower_height'] = thrower_height
29-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt =1e-3, save_freq =10)
25+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt =1e-3, save_freq =10)
3026
eods[i] = simulated_results.times[-1]
3127

32-
# Step 5: Analysis
28+
# Step 4: Analysis
3329
print('For a reasonable range of heights, impact time is between {} and {}'.format(round(eods[0],3), round(eods[-1],3)))
3430
sensitivity = (eods[-1]-eods[0])/(thrower_height_range[-1] - thrower_height_range[0])
3531
print(' - Average sensitivity: {} s per cm height'.format(round(sensitivity/100, 6)))
@@ -40,7 +36,7 @@ def future_load(t, x=None):
4036
eods = np.empty(len(throw_speed_range))
4137
for (i, throw_speed) in zip(range(len(throw_speed_range)), throw_speed_range):
4238
m.parameters['throwing_speed'] = throw_speed
43-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], options={'dt':1e-3, 'save_freq':10})
39+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], options={'dt':1e-3, 'save_freq':10})
4440
eods[i] = simulated_results.times[-1]
4541

4642
print('\nFor a reasonable range of throwing speeds, impact time is between {} and {}'.format(round(eods[0],3), round(eods[-1],3)))

examples/state_limits.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@ def run_example():
1515
# Step 1: Create instance of model (without drag)
1616
m = ThrownObject( cd = 0 )
1717

18-
# Step 2: Setup for simulation
19-
def future_load(t, x=None):
20-
return {}
21-
22-
# add state limits
18+
# Step 2: add state limits
2319
m.state_limits = {
2420
# object may not go below ground height
2521
'x': (0, inf),
@@ -30,7 +26,7 @@ def future_load(t, x=None):
3026

3127
# Step 3: Simulate to impact
3228
event = 'impact'
33-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt=0.005, save_freq=1)
29+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt=0.005, save_freq=1)
3430

3531
# Print states
3632
print('Example 1')
@@ -42,7 +38,7 @@ def future_load(t, x=None):
4238
x0 = m.initialize(u = {}, z = {})
4339
x0['x'] = -1
4440

45-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt=0.005, save_freq=1, x = x0)
41+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt=0.005, save_freq=1, x=x0)
4642

4743
# Print states
4844
print('Example 2')
@@ -57,7 +53,7 @@ def future_load(t, x=None):
5753
m.parameters['g'] = -50000000
5854

5955
print('Example 3')
60-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt=0.005, save_freq=0.3, x = x0, print = True, progress = False)
56+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt=0.005, save_freq=0.3, x=x0, print=True, progress=False)
6157

6258
# Note that the limits can also be applied manually using the apply_limits function
6359
print('limiting states')

src/progpy/predictors/monte_carlo.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,56 @@ class MonteCarlo(Predictor):
1919
2020
Configuration Parameters
2121
------------------------------
22-
t0 : float
23-
Initial time at which prediction begins, e.g., 0
24-
dt : float
25-
Simulation step size (s), e.g., 0.1
26-
events : list[str]
27-
Events to predict (subset of model.events) e.g., ['event1', 'event2']
28-
horizon : float
29-
Prediction horizon (s)
30-
n_samples : int
31-
Number of samples to use. If not specified, a default value is used. If state is type UnweightedSamples and n_samples is not provided, the provided unweighted samples will be used directly.
32-
save_freq : float
33-
Frequency at which results are saved (s)
34-
save_pts : list[float]
35-
Any additional savepoints (s) e.g., [10.1, 22.5]
22+
n_samples : int, optional
23+
Default number of samples to use. If not specified, a default value is used. If state is type UnweightedSamples and n_samples is not provided, the provided unweighted samples will be used directly.
24+
save_freq : float, optional
25+
Default frequency at which results are saved (s).
3626
"""
3727

28+
__DEFAULT_N_SAMPLES = 100 # Default number of samples to use, if none specified and not UncertainData
29+
3830
default_parameters = {
39-
'n_samples': 100 # Default number of samples to use, if none specified
31+
'n_samples': None
4032
}
4133

42-
def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) -> PredictionResults:
34+
def predict(self, state: UncertainData, future_loading_eqn: Callable = None, **kwargs) -> PredictionResults:
35+
"""
36+
Perform a single prediction
37+
38+
Parameters
39+
----------
40+
state : UncertainData
41+
Distribution representing current state of the system
42+
future_loading_eqn : function (t, x=None) -> z, optional
43+
Function to generate an estimate of loading at future time t, and state x
44+
45+
Keyword Arguments
46+
------------------
47+
t0 : float, optional
48+
Initial time at which prediction begins, e.g., 0
49+
dt : float, optional
50+
Simulation step size (s), e.g., 0.1
51+
events : list[str], optional
52+
Events to predict (subset of model.events) e.g., ['event1', 'event2']
53+
horizon : float, optional
54+
Prediction horizon (s)
55+
n_samples : int, optional
56+
Number of samples to use. If not specified, a default value is used. If state is type UnweightedSamples and n_samples is not provided, the provided unweighted samples will be used directly.
57+
save_freq : float, optional
58+
Frequency at which results are saved (s)
59+
save_pts : list[float], optional
60+
Any additional savepoints (s) e.g., [10.1, 22.5]
61+
62+
Return
63+
----------
64+
result from prediction, including: NameTuple
65+
* times (List[float]): Times for each savepoint such that inputs.snapshot(i), states.snapshot(i), outputs.snapshot(i), and event_states.snapshot(i) are all at times[i]
66+
* inputs (Prediction): Inputs at each savepoint such that inputs.snapshot(i) is the input distribution (type UncertainData) at times[i]
67+
* states (Prediction): States at each savepoint such that states.snapshot(i) is the state distribution (type UncertainData) at times[i]
68+
* outputs (Prediction): Outputs at each savepoint such that outputs.snapshot(i) is the output distribution (type UncertainData) at times[i]
69+
* event_states (Prediction): Event states at each savepoint such that event_states.snapshot(i) is the event state distribution (type UncertainData) at times[i]
70+
* time_of_event (UncertainData): Distribution of predicted Time of Event (ToE) for each predicted event, represented by some subclass of UncertaintData (e.g., MultivariateNormalDist)
71+
"""
4372
if isinstance(state, dict) or isinstance(state, self.model.StateContainer):
4473
from progpy.uncertain_data import ScalarData
4574
state = ScalarData(state, _type = self.model.StateContainer)
@@ -48,16 +77,28 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
4877
else:
4978
raise TypeError("state must be UncertainData, dict, or StateContainer")
5079

80+
if future_loading_eqn is None:
81+
future_loading_eqn = lambda t, x=None: self.model.InputContainer({})
82+
5183
params = deepcopy(self.parameters) # copy parameters
5284
params.update(kwargs) # update for specific run
5385
params['print'] = False
5486
params['progress'] = False
5587

88+
if not isinstance(state, UnweightedSamples) and params['n_samples'] is None:
89+
# if not unweighted samples, some sample number is required, so set to default.
90+
params['n_samples'] = MonteCarlo.__DEFAULT_N_SAMPLES
91+
elif isinstance(state, UnweightedSamples) and params['n_samples'] is None:
92+
params['n_samples'] = len(state) # number of samples is from provided state
93+
5694
if len(params['events']) == 0 and 'horizon' not in params:
5795
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")
5896

59-
# Sample from state if n_samples specified or state is not UnweightedSamples
60-
if not isinstance(state, UnweightedSamples) or len(state) != params['n_samples']:
97+
# Sample from state if n_samples specified or state is not UnweightedSamples (Case 2)
98+
# Or if is Unweighted samples, but there are the wrong number of samples (Case 1)
99+
if (
100+
(isinstance(state, UnweightedSamples) and len(state) != params['n_samples']) # Case 1
101+
or not isinstance(state, UnweightedSamples)): # Case 2
61102
state = state.sample(params['n_samples'])
62103

63104
es_eqn = self.model.event_state

src/progpy/predictors/unscented_transform.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,14 @@ def state_transition(x, dt):
123123
self.filter = kalman.UnscentedKalmanFilter(num_states, num_measurements, self.parameters['dt'], measure, state_transition, self.sigma_points)
124124
self.filter.Q = self.parameters['Q']
125125

126-
def predict(self, state, future_loading_eqn: Callable, **kwargs) -> PredictionResults:
126+
def predict(self, state, future_loading_eqn: Callable = None, **kwargs) -> PredictionResults:
127127
"""
128128
Perform a single prediction
129129
130130
Parameters
131131
----------
132132
state (UncertaintData): Distribution of states
133-
future_loading_eqn : function (t, x={}) -> z
133+
future_loading_eqn: function (t, x=None) -> z, optional
134134
Function to generate an estimate of loading at future time t
135135
options (optional, kwargs): configuration options\n
136136
Any additional configuration values. Note: These parameters can also be specified in the predictor constructor. The following configuration parameters are supported: \n
@@ -169,6 +169,9 @@ def predict(self, state, future_loading_eqn: Callable, **kwargs) -> PredictionRe
169169
else:
170170
raise TypeError("state must be UncertainData, dict, or StateContainer")
171171

172+
if future_loading_eqn is None:
173+
future_loading_eqn = lambda t, x=None: self.model.InputContainer({})
174+
172175
params = deepcopy(self.parameters) # copy parameters
173176
params.update(kwargs) # update for specific run
174177

tests/test_predictors.py

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,9 @@ def test_UTP_ThrownObject(self):
6666
m = ThrownObject()
6767
pred = UnscentedTransformPredictor(m)
6868
samples = MultivariateNormalDist(['x', 'v'], [1.83, 40], [[0.1, 0.01], [0.01, 0.1]])
69-
def future_loading(t, x={}):
70-
return {}
7169

72-
mc_results = pred.predict(samples, future_loading, dt=0.01, save_freq=1)
70+
# No future loading (i.e., no load)
71+
mc_results = pred.predict(samples, dt=0.01, save_freq=1)
7372
self.assertAlmostEqual(mc_results.time_of_event.mean['impact'], 8.21, 0)
7473
self.assertAlmostEqual(mc_results.time_of_event.mean['falling'], 4.15, 0)
7574
# self.assertAlmostEqual(mc_results.times[-1], 9, 1) # Saving every second, last time should be around the 1s after impact event (because one of the sigma points fails afterwards)
@@ -126,10 +125,9 @@ def future_loading(t, x=None):
126125
def test_MC(self):
127126
m = ThrownObject()
128127
mc = MonteCarlo(m)
129-
def future_loading(t=None, x=None):
130-
return {}
131-
132-
mc.predict(m.initialize(), future_loading, dt=0.2, num_samples=3, save_freq=1)
128+
129+
# Test with empty future loading (i.e., no load)
130+
mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1)
133131

134132
def test_prediction_mvnormaldist(self):
135133
times = list(range(10))
@@ -385,6 +383,48 @@ def test_utp_surrogate(self):
385383

386384
def test_mc_surrogate(self):
387385
self._test_surrogate_pred(MonteCarlo)
386+
387+
def test_mc_num_samples(self):
388+
"""
389+
This test confirms that monte carlos sampling logic works as expected
390+
"""
391+
m = ThrownObject()
392+
def future_load(t, x=None):
393+
return m.InputContainer({})
394+
395+
pred = MonteCarlo(m)
396+
397+
# First test- scalar input
398+
x_scalar = ScalarData({'x': 10, 'v': 0})
399+
# Should default to 100 samples
400+
result = pred.predict(x_scalar, future_load)
401+
self.assertEqual(len(result.time_of_event), 100)
402+
# Repeat with less samples
403+
result = pred.predict(x_scalar, future_load, n_samples=10)
404+
self.assertEqual(len(result.time_of_event), 10)
405+
406+
# Second test- Same, but with multivariate normal input
407+
# Behavior should be the same
408+
x_mvnormal = MultivariateNormalDist(['x', 'v'], [10, 0], [[0.1, 0], [0, 0.1]])
409+
# Should default to 100 samples
410+
result = pred.predict(x_mvnormal, future_load)
411+
self.assertEqual(len(result.time_of_event), 100)
412+
# Repeat with less samples
413+
result = pred.predict(x_mvnormal, future_load, n_samples=10)
414+
self.assertEqual(len(result.time_of_event), 10)
415+
416+
# Third test- UnweightedSamples input
417+
x_uwsamples = UnweightedSamples([{'x': 10, 'v': 0}, {'x': 9.9, 'v': 0.1}, {'x': 10.1, 'v': -0.1}])
418+
# Should default to same as x_uwsamples - HERE IS THE DIFFERENCE FROM OTHER TYPES
419+
result = pred.predict(x_uwsamples, future_load)
420+
self.assertEqual(len(result.time_of_event), 3)
421+
# Should be exact same data, in the same order
422+
for i in range(3):
423+
self.assertEqual(result.states[i][0]['x'], x_uwsamples[i]['x'])
424+
self.assertEqual(result.states[i][0]['v'], x_uwsamples[i]['v'])
425+
# Repeat with more samples
426+
result = pred.predict(x_uwsamples, future_load, n_samples=10)
427+
self.assertEqual(len(result.time_of_event), 10)
388428

389429
# This allows the module to be executed directly
390430
def main():

0 commit comments

Comments
 (0)