Skip to content

Commit bf99af4

Browse files
committed
Add test
1 parent 1765416 commit bf99af4

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

tests/test_predictors.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,48 @@ def test_utp_surrogate(self):
385385

386386
def test_mc_surrogate(self):
387387
self._test_surrogate_pred(MonteCarlo)
388+
389+
def test_mc_num_samples(self):
390+
"""
391+
This test confirms that monte carlos sampling logic works as expected
392+
"""
393+
m = ThrownObject()
394+
def future_load(t, x=None):
395+
return m.InputContainer({})
396+
397+
pred = MonteCarlo(m)
398+
399+
# First test- scalar input
400+
x_scalar = ScalarData({'x': 10, 'v': 0})
401+
# Should default to 100 samples
402+
result = pred.predict(x_scalar, future_load)
403+
self.assertEqual(len(result.time_of_event), 100)
404+
# Repeat with less samples
405+
result = pred.predict(x_scalar, future_load, n_samples=10)
406+
self.assertEqual(len(result.time_of_event), 10)
407+
408+
# Second test- Same, but with multivariate normal input
409+
# Behavior should be the same
410+
x_mvnormal = MultivariateNormalDist(['x', 'v'], [10, 0], [[0.1, 0], [0, 0.1]])
411+
# Should default to 100 samples
412+
result = pred.predict(x_mvnormal, future_load)
413+
self.assertEqual(len(result.time_of_event), 100)
414+
# Repeat with less samples
415+
result = pred.predict(x_mvnormal, future_load, n_samples=10)
416+
self.assertEqual(len(result.time_of_event), 10)
417+
418+
# Third test- UnweightedSamples input
419+
x_uwsamples = UnweightedSamples([{'x': 10, 'v': 0}, {'x': 9.9, 'v': 0.1}, {'x': 10.1, 'v': -0.1}])
420+
# Should default to same as x_uwsamples - HERE IS THE DIFFERENCE FROM OTHER TYPES
421+
result = pred.predict(x_uwsamples, future_load)
422+
self.assertEqual(len(result.time_of_event), 3)
423+
# Should be exact same data, in the same order
424+
for i in range(3):
425+
self.assertEqual(result.states[i][0]['x'], x_uwsamples[i]['x'])
426+
self.assertEqual(result.states[i][0]['v'], x_uwsamples[i]['v'])
427+
# Repeat with more samples
428+
result = pred.predict(x_uwsamples, future_load, n_samples=10)
429+
self.assertEqual(len(result.time_of_event), 10)
388430

389431
# This allows the module to be executed directly
390432
def main():

0 commit comments

Comments
 (0)