Skip to content

Commit ff234f0

Browse files
authored
Merge pull request #78 from nasa/feature/empty_future_loading
Make future loading optional in prediction
2 parents 670e1cf + a34172e commit ff234f0

File tree

8 files changed

+29
-40
lines changed

8 files changed

+29
-40
lines changed

examples/basic_example.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
def run_example():
2121
# Step 1: Setup model & future loading
2222
m = ThrownObject(process_noise = 1)
23-
def future_loading(t, x = None):
24-
# No load for a thrown object
25-
return m.InputContainer({})
2623
initial_state = m.initialize()
2724

2825
# Step 2: Demonstrating state estimator
@@ -42,7 +39,7 @@ def future_loading(t, x = None):
4239
# Step 2c: Perform state estimation step, given some measurement, above what's expected
4340
example_measurements = m.OutputContainer({'x': 7.5})
4441
t = 0.1
45-
u = future_loading(t)
42+
u = m.InputContainer({})
4643
filt.estimate(t, u, example_measurements) # Update state, given (example) sensor data
4744

4845
# Step 2d: Print & Plot Resulting Posterior State
@@ -65,7 +62,7 @@ def future_loading(t, x = None):
6562
# Step 3b: Perform a prediction
6663
NUM_SAMPLES = 50
6764
STEP_SIZE = 0.01
68-
mc_results = mc.predict(filt.x, future_loading, n_samples = NUM_SAMPLES, dt=STEP_SIZE, save_freq=STEP_SIZE)
65+
mc_results = mc.predict(filt.x, n_samples = NUM_SAMPLES, dt=STEP_SIZE, save_freq=STEP_SIZE)
6966
print('Predicted time of event (ToE): ', mc_results.time_of_event.mean)
7067
# Here there are 2 events predicted, when the object starts falling, and when it impacts the ground.
7168

examples/horizon.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818

1919
def run_example():
2020
# Step 1: Setup model & future loading
21-
def future_loading(t, x = None):
22-
return {}
23-
m = ThrownObject(process_noise = 0.25, measurement_noise = 0.2)
21+
m = ThrownObject(process_noise=0.25, measurement_noise=0.2)
2422
initial_state = m.initialize()
2523

2624
# Step 2: Demonstrating state estimator
@@ -53,7 +51,7 @@ def future_loading(t, x = None):
5351
PREDICTION_HORIZON = 7.75
5452
samples = filt.x # Since we're using a particle filter, which is also sample-based, we can directly use the samples, without changes
5553
STEP_SIZE = 0.01
56-
mc_results = mc.predict(samples, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON)
54+
mc_results = mc.predict(samples, dt=STEP_SIZE, horizon=PREDICTION_HORIZON)
5755
print("\nPredicted Time of Event:")
5856
metrics = mc_results.time_of_event.metrics()
5957
pprint(metrics) # Note this takes some time

examples/predict_specific_event.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ def run_example():
1212
m = ThrownObject()
1313
initial_state = m.initialize()
1414
load = m.InputContainer({}) # Optimization - create once
15-
def future_loading(t, x = None):
16-
return load
1715

1816
## State Estimation - perform a single ukf state estimate step
1917
filt = state_estimators.UnscentedKalmanFilter(m, initial_state)
@@ -24,7 +22,7 @@ def future_loading(t, x = None):
2422
pred = predictors.UnscentedTransformPredictor(m)
2523

2624
# Predict with a step size of 0.1
27-
mc_results = pred.predict(filt.x, future_loading, dt=0.1, save_freq= 1, events=['impact'])
25+
mc_results = pred.predict(filt.x, dt=0.1, save_freq= 1, events=['impact'])
2826

2927
# Print Results
3028
for i, time in enumerate(mc_results.times):

examples/sensitivity.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,18 @@ def run_example():
1414
# Step 1: Create instance of model
1515
m = ThrownObject()
1616

17-
# Step 2: Setup for simulation
18-
def future_load(t, x=None):
19-
return m.InputContainer({})
20-
21-
# Step 3: Setup range on parameters considered
17+
# Step 2: Setup range on parameters considered
2218
thrower_height_range = np.arange(1.2, 2.1, 0.1)
2319

24-
# Step 4: Sim for each
20+
# Step 3: Sim for each
2521
event = 'impact'
2622
eods = np.empty(len(thrower_height_range))
2723
for (i, thrower_height) in zip(range(len(thrower_height_range)), thrower_height_range):
2824
m.parameters['thrower_height'] = thrower_height
29-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt =1e-3, save_freq =10)
25+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt =1e-3, save_freq =10)
3026
eods[i] = simulated_results.times[-1]
3127

32-
# Step 5: Analysis
28+
# Step 4: Analysis
3329
print('For a reasonable range of heights, impact time is between {} and {}'.format(round(eods[0],3), round(eods[-1],3)))
3430
sensitivity = (eods[-1]-eods[0])/(thrower_height_range[-1] - thrower_height_range[0])
3531
print(' - Average sensitivity: {} s per cm height'.format(round(sensitivity/100, 6)))
@@ -40,7 +36,7 @@ def future_load(t, x=None):
4036
eods = np.empty(len(throw_speed_range))
4137
for (i, throw_speed) in zip(range(len(throw_speed_range)), throw_speed_range):
4238
m.parameters['throwing_speed'] = throw_speed
43-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], options={'dt':1e-3, 'save_freq':10})
39+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], options={'dt':1e-3, 'save_freq':10})
4440
eods[i] = simulated_results.times[-1]
4541

4642
print('\nFor a reasonable range of throwing speeds, impact time is between {} and {}'.format(round(eods[0],3), round(eods[-1],3)))

examples/state_limits.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,7 @@ def run_example():
1515
# Step 1: Create instance of model (without drag)
1616
m = ThrownObject( cd = 0 )
1717

18-
# Step 2: Setup for simulation
19-
def future_load(t, x=None):
20-
return {}
21-
22-
# add state limits
18+
# Step 2: add state limits
2319
m.state_limits = {
2420
# object may not go below ground height
2521
'x': (0, inf),
@@ -30,7 +26,7 @@ def future_load(t, x=None):
3026

3127
# Step 3: Simulate to impact
3228
event = 'impact'
33-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt=0.005, save_freq=1)
29+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt=0.005, save_freq=1)
3430

3531
# Print states
3632
print('Example 1')
@@ -42,7 +38,7 @@ def future_load(t, x=None):
4238
x0 = m.initialize(u = {}, z = {})
4339
x0['x'] = -1
4440

45-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt=0.005, save_freq=1, x = x0)
41+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt=0.005, save_freq=1, x=x0)
4642

4743
# Print states
4844
print('Example 2')
@@ -57,7 +53,7 @@ def future_load(t, x=None):
5753
m.parameters['g'] = -50000000
5854

5955
print('Example 3')
60-
simulated_results = m.simulate_to_threshold(future_load, threshold_keys=[event], dt=0.005, save_freq=0.3, x = x0, print = True, progress = False)
56+
simulated_results = m.simulate_to_threshold(threshold_keys=[event], dt=0.005, save_freq=0.3, x=x0, print=True, progress=False)
6157

6258
# Note that the limits can also be applied manually using the apply_limits function
6359
print('limiting states')

src/progpy/predictors/monte_carlo.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@ class MonteCarlo(Predictor):
3131
'n_samples': None
3232
}
3333

34-
def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) -> PredictionResults:
34+
def predict(self, state: UncertainData, future_loading_eqn: Callable = None, **kwargs) -> PredictionResults:
3535
"""
3636
Perform a single prediction
3737
3838
Parameters
3939
----------
4040
state : UncertainData
4141
Distribution representing current state of the system
42-
future_loading_eqn : function (t, x) -> z
42+
future_loading_eqn : function (t, x=None) -> z, optional
4343
Function to generate an estimate of loading at future time t, and state x
4444
4545
Keyword Arguments
@@ -77,6 +77,9 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
7777
else:
7878
raise TypeError("state must be UncertainData, dict, or StateContainer")
7979

80+
if future_loading_eqn is None:
81+
future_loading_eqn = lambda t, x=None: self.model.InputContainer({})
82+
8083
params = deepcopy(self.parameters) # copy parameters
8184
params.update(kwargs) # update for specific run
8285
params['print'] = False

src/progpy/predictors/unscented_transform.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,14 @@ def state_transition(x, dt):
123123
self.filter = kalman.UnscentedKalmanFilter(num_states, num_measurements, self.parameters['dt'], measure, state_transition, self.sigma_points)
124124
self.filter.Q = self.parameters['Q']
125125

126-
def predict(self, state, future_loading_eqn: Callable, **kwargs) -> PredictionResults:
126+
def predict(self, state, future_loading_eqn: Callable = None, **kwargs) -> PredictionResults:
127127
"""
128128
Perform a single prediction
129129
130130
Parameters
131131
----------
132132
state (UncertaintData): Distribution of states
133-
future_loading_eqn : function (t, x={}) -> z
133+
future_loading_eqn: function (t, x=None) -> z, optional
134134
Function to generate an estimate of loading at future time t
135135
options (optional, kwargs): configuration options\n
136136
Any additional configuration values. Note: These parameters can also be specified in the predictor constructor. The following configuration parameters are supported: \n
@@ -169,6 +169,9 @@ def predict(self, state, future_loading_eqn: Callable, **kwargs) -> PredictionRe
169169
else:
170170
raise TypeError("state must be UncertainData, dict, or StateContainer")
171171

172+
if future_loading_eqn is None:
173+
future_loading_eqn = lambda t, x=None: self.model.InputContainer({})
174+
172175
params = deepcopy(self.parameters) # copy parameters
173176
params.update(kwargs) # update for specific run
174177

tests/test_predictors.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,9 @@ def test_UTP_ThrownObject(self):
6666
m = ThrownObject()
6767
pred = UnscentedTransformPredictor(m)
6868
samples = MultivariateNormalDist(['x', 'v'], [1.83, 40], [[0.1, 0.01], [0.01, 0.1]])
69-
def future_loading(t, x={}):
70-
return {}
7169

72-
mc_results = pred.predict(samples, future_loading, dt=0.01, save_freq=1)
70+
# No future loading (i.e., no load)
71+
mc_results = pred.predict(samples, dt=0.01, save_freq=1)
7372
self.assertAlmostEqual(mc_results.time_of_event.mean['impact'], 8.21, 0)
7473
self.assertAlmostEqual(mc_results.time_of_event.mean['falling'], 4.15, 0)
7574
# self.assertAlmostEqual(mc_results.times[-1], 9, 1) # Saving every second, last time should be around the 1s after impact event (because one of the sigma points fails afterwards)
@@ -126,10 +125,9 @@ def future_loading(t, x=None):
126125
def test_MC(self):
127126
m = ThrownObject()
128127
mc = MonteCarlo(m)
129-
def future_loading(t=None, x=None):
130-
return {}
131-
132-
mc.predict(m.initialize(), future_loading, dt=0.2, num_samples=3, save_freq=1)
128+
129+
# Test with empty future loading (i.e., no load)
130+
mc.predict(m.initialize(), dt=0.2, num_samples=3, save_freq=1)
133131

134132
def test_prediction_mvnormaldist(self):
135133
times = list(range(10))

0 commit comments

Comments
 (0)