@@ -35,8 +35,10 @@ class MonteCarlo(Predictor):
3535 Any additional savepoints (s) e.g., [10.1, 22.5]
3636 """
3737
38+ __DEFAULT_N_SAMPLES = 100 # Default number of samples to use, if none specified and not UncertainData
39+
3840 default_parameters = {
39- 'n_samples' : 100 # Default number of samples to use, if none specified
41+ 'n_samples' : None
4042 }
4143
4244 def predict (self , state : UncertainData , future_loading_eqn : Callable , ** kwargs ) -> PredictionResults :
@@ -53,11 +55,20 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
5355 params ['print' ] = False
5456 params ['progress' ] = False
5557
58+ if not isinstance (state , UnweightedSamples ) and params ['n_samples' ] is None :
59+ # if not unweighted samples, some sample number is required, so set to default.
60+ params ['n_samples' ] = MonteCarlo .__DEFAULT_N_SAMPLES
61+ elif isinstance (state , UnweightedSamples ) and params ['n_samples' ] is None :
62+ params ['n_samples' ] = len (state ) # number of samples is from provided state
63+
5664 if len (params ['events' ]) == 0 and 'horizon' not in params :
5765 raise ValueError ("If specifying no event (i.e., simulate to time), must specify horizon" )
5866
59- # Sample from state if n_samples specified or state is not UnweightedSamples
60- if not isinstance (state , UnweightedSamples ) or len (state ) != params ['n_samples' ]:
67+ # Sample from state if n_samples specified or state is not UnweightedSamples (Case 2)
68+ # Or if is Unweighted samples, but there are the wrong number of samples (Case 1)
69+ if (
70+ (isinstance (state , UnweightedSamples ) and len (state ) != params ['n_samples' ]) # Case 1
71+ or not isinstance (state , UnweightedSamples )): # Case 2
6172 state = state .sample (params ['n_samples' ])
6273
6374 es_eqn = self .model .event_state
0 commit comments