|
| 1 | +# Copyright © 2021 United States Government as represented by the Administrator of the |
| 2 | +# National Aeronautics and Space Administration. All Rights Reserved. |
| 3 | + |
| 4 | +from io import StringIO |
| 5 | +import sys |
| 6 | +import unittest |
| 7 | + |
| 8 | +from progpy import predictors |
| 9 | +from progpy.models import ThrownObject |
| 10 | + |
| 11 | +class TestHorizon(unittest.TestCase): |
| 12 | + def setUp(self): |
| 13 | + # set stdout (so it won't print) |
| 14 | + sys.stdout = StringIO() |
| 15 | + |
| 16 | + def tearDown(self): |
| 17 | + sys.stdout = sys.__stdout__ |
| 18 | + |
| 19 | + def test_horizon_ex(self): |
| 20 | + # Setup model |
| 21 | + m = ThrownObject(process_noise=0.25, measurement_noise=0.2) |
| 22 | + # Change parameters (to make simulation faster) |
| 23 | + m.parameters['thrower_height'] = 1.0 |
| 24 | + m.parameters['throwing_speed'] = 10.0 |
| 25 | + initial_state = m.initialize() |
| 26 | + |
| 27 | + # Define future loading (necessary for prediction call) |
| 28 | + def future_loading(t, x=None): |
| 29 | + return {} |
| 30 | + |
| 31 | + # Setup Predictor (smaller sample size for efficiency) |
| 32 | + mc = predictors.MonteCarlo(m) |
| 33 | + mc.parameters['n_samples'] = 50 |
| 34 | + |
| 35 | + # Perform a prediction |
| 36 | + # With this horizon, all samples will reach 'falling', but only some will reach 'impact' |
| 37 | + PREDICTION_HORIZON = 2.127 |
| 38 | + STEP_SIZE = 0.001 |
| 39 | + mc_results = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon = PREDICTION_HORIZON) |
| 40 | + |
| 41 | + # 'falling' happens before the horizon is met |
| 42 | + falling_res = [mc_results.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results.time_of_event[iter]['falling'] is not None] |
| 43 | + self.assertEqual(len(falling_res), mc.parameters['n_samples']) |
| 44 | + |
| 45 | + # 'impact' happens around the horizon, so some samples have reached this event and others haven't |
| 46 | + impact_res = [mc_results.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results.time_of_event[iter]['impact'] is not None] |
| 47 | + self.assertLess(len(impact_res), mc.parameters['n_samples']) |
| 48 | + |
| 49 | + # Try again with very low prediction_horizon, where no events are reached |
| 50 | + # Note: here we count how many None values there are for each event (in the above and below examples, we count values that are NOT None) |
| 51 | + mc_results_no_event = mc.predict(initial_state, future_loading, dt=STEP_SIZE, horizon=0.3) |
| 52 | + falling_res_no_event = [mc_results_no_event.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results_no_event.time_of_event[iter]['falling'] is None] |
| 53 | + impact_res_no_event = [mc_results_no_event.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results_no_event.time_of_event[iter]['impact'] is None] |
| 54 | + self.assertEqual(len(falling_res_no_event), mc.parameters['n_samples']) |
| 55 | + self.assertEqual(len(impact_res_no_event), mc.parameters['n_samples']) |
| 56 | + |
| 57 | + # Finally, try without horizon, all events should be reached for all samples |
| 58 | + mc_results_all_event = mc.predict(initial_state, future_loading, dt=STEP_SIZE) |
| 59 | + falling_res_all_event = [mc_results_all_event.time_of_event[iter]['falling'] for iter in range(mc.parameters['n_samples']) if mc_results_all_event.time_of_event[iter]['falling'] is not None] |
| 60 | + impact_res_all_event = [mc_results_all_event.time_of_event[iter]['impact'] for iter in range(mc.parameters['n_samples']) if mc_results_all_event.time_of_event[iter]['impact'] is not None] |
| 61 | + self.assertEqual(len(falling_res_all_event), mc.parameters['n_samples']) |
| 62 | + self.assertEqual(len(impact_res_all_event), mc.parameters['n_samples']) |
| 63 | + |
| 64 | +# This allows the module to be executed directly |
| 65 | +def run_tests(): |
| 66 | + unittest.main() |
| 67 | + |
| 68 | +def main(): |
| 69 | + load_test = unittest.TestLoader() |
| 70 | + runner = unittest.TextTestRunner() |
| 71 | + print("\n\nTesting Horizon functionality") |
| 72 | + result = runner.run(load_test.loadTestsFromTestCase(TestHorizon)).wasSuccessful() |
| 73 | + |
| 74 | + if not result: |
| 75 | + raise Exception("Failed test") |
| 76 | + |
| 77 | +if __name__ == '__main__': |
| 78 | + main() |
0 commit comments