@@ -385,6 +385,48 @@ def test_utp_surrogate(self):
385385
386386 def test_mc_surrogate (self ):
387387 self ._test_surrogate_pred (MonteCarlo )
388+
389+ def test_mc_num_samples (self ):
390+ """
391+ This test confirms that monte carlos sampling logic works as expected
392+ """
393+ m = ThrownObject ()
394+ def future_load (t , x = None ):
395+ return m .InputContainer ({})
396+
397+ pred = MonteCarlo (m )
398+
399+ # First test- scalar input
400+ x_scalar = ScalarData ({'x' : 10 , 'v' : 0 })
401+ # Should default to 100 samples
402+ result = pred .predict (x_scalar , future_load )
403+ self .assertEqual (len (result .time_of_event ), 100 )
404+ # Repeat with less samples
405+ result = pred .predict (x_scalar , future_load , n_samples = 10 )
406+ self .assertEqual (len (result .time_of_event ), 10 )
407+
408+ # Second test- Same, but with multivariate normal input
409+ # Behavior should be the same
410+ x_mvnormal = MultivariateNormalDist (['x' , 'v' ], [10 , 0 ], [[0.1 , 0 ], [0 , 0.1 ]])
411+ # Should default to 100 samples
412+ result = pred .predict (x_mvnormal , future_load )
413+ self .assertEqual (len (result .time_of_event ), 100 )
414+ # Repeat with less samples
415+ result = pred .predict (x_mvnormal , future_load , n_samples = 10 )
416+ self .assertEqual (len (result .time_of_event ), 10 )
417+
418+ # Third test- UnweightedSamples input
419+ x_uwsamples = UnweightedSamples ([{'x' : 10 , 'v' : 0 }, {'x' : 9.9 , 'v' : 0.1 }, {'x' : 10.1 , 'v' : - 0.1 }])
420+ # Should default to same as x_uwsamples - HERE IS THE DIFFERENCE FROM OTHER TYPES
421+ result = pred .predict (x_uwsamples , future_load )
422+ self .assertEqual (len (result .time_of_event ), 3 )
423+ # Should be exact same data, in the same order
424+ for i in range (3 ):
425+ self .assertEqual (result .states [i ][0 ]['x' ], x_uwsamples [i ]['x' ])
426+ self .assertEqual (result .states [i ][0 ]['v' ], x_uwsamples [i ]['v' ])
427+ # Repeat with more samples
428+ result = pred .predict (x_uwsamples , future_load , n_samples = 10 )
429+ self .assertEqual (len (result .time_of_event ), 10 )
388430
389431# This allows the module to be executed directly
390432def main ():
0 commit comments