Skip to content

Commit e7d36e7

Browse files
committed
Fix default n_samples
1 parent b9e3ff5 commit e7d36e7

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/progpy/predictors/monte_carlo.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ class MonteCarlo(Predictor):
3535
Any additional savepoints (s) e.g., [10.1, 22.5]
3636
"""
3737

38+
__DEFAULT_N_SAMPLES = 100 # Default number of samples to use, if none specified and not UncertainData
39+
3840
default_parameters = {
39-
'n_samples': 100 # Default number of samples to use, if none specified
41+
'n_samples': None
4042
}
4143

4244
def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs) -> PredictionResults:
@@ -53,11 +55,20 @@ def predict(self, state: UncertainData, future_loading_eqn: Callable, **kwargs)
5355
params['print'] = False
5456
params['progress'] = False
5557

58+
if not isinstance(state, UnweightedSamples) and params['n_samples'] is None:
59+
# if not unweighted samples, some sample number is required, so set to default.
60+
params['n_samples'] = MonteCarlo.__DEFAULT_N_SAMPLES
61+
elif isinstance(state, UnweightedSamples) and params['n_samples'] is None:
62+
params['n_samples'] = len(state) # number of samples is from provided state
63+
5664
if len(params['events']) == 0 and 'horizon' not in params:
5765
raise ValueError("If specifying no event (i.e., simulate to time), must specify horizon")
5866

59-
# Sample from state if n_samples specified or state is not UnweightedSamples
60-
if not isinstance(state, UnweightedSamples) or len(state) != params['n_samples']:
67+
# Sample from state if n_samples specified or state is not UnweightedSamples (Case 2)
68+
# Or if is Unweighted samples, but there are the wrong number of samples (Case 1)
69+
if (
70+
(isinstance(state, UnweightedSamples) and len(state) != params['n_samples']) # Case 1
71+
or not isinstance(state, UnweightedSamples)): # Case 2
6172
state = state.sample(params['n_samples'])
6273

6374
es_eqn = self.model.event_state

0 commit comments

Comments
 (0)