@@ -101,7 +101,9 @@ def instantiate_steppers(
101101 model : Model ,
102102 steps : list [Step ],
103103 selected_steps : Mapping [type [BlockedStep ], list [Any ]],
104+ * ,
104105 step_kwargs : dict [str , dict ] | None = None ,
106+ initial_point : PointType | None = None ,
105107) -> Step | list [Step ]:
106108 """Instantiate steppers assigned to the model variables.
107109
@@ -131,13 +133,22 @@ def instantiate_steppers(
131133 step_kwargs = {}
132134
133135 used_keys = set ()
134- for step_class , vars in selected_steps .items ():
135- if vars :
136- name = getattr (step_class , "name" )
137- args = step_kwargs .get (name , {})
138- used_keys .add (name )
139- step = step_class (vars = vars , model = model , ** args )
140- steps .append (step )
136+ if selected_steps :
137+ if initial_point is None :
138+ initial_point = model .initial_point ()
139+
140+ for step_class , vars in selected_steps .items ():
141+ if vars :
142+ name = getattr (step_class , "name" )
143+ kwargs = step_kwargs .get (name , {})
144+ used_keys .add (name )
145+ step = step_class (
146+ vars = vars ,
147+ model = model ,
148+ initial_point = initial_point ,
149+ ** kwargs ,
150+ )
151+ steps .append (step )
141152
142153 unused_args = set (step_kwargs ).difference (used_keys )
143154 if unused_args :
@@ -161,18 +172,22 @@ def assign_step_methods(
161172 model : Model ,
162173 step : Step | Sequence [Step ] | None = None ,
163174 methods : Sequence [type [BlockedStep ]] | None = None ,
164- step_kwargs : dict [str , Any ] | None = None ,
165- ) -> Step | list [Step ]:
175+ ) -> tuple [list [Step ], dict [type [BlockedStep ], list [Variable ]]]:
166176 """Assign model variables to appropriate step methods.
167177
168- Passing a specified model will auto-assign its constituent stochastic
169- variables to step methods based on the characteristics of the variables.
178+ Passing a specified model will auto-assign its constituent value
179+ variables to step methods based on the characteristics of the respective
180+ random variables, and whether the logp can be differentiated with respect to it.
181+
170182 This function is intended to be called automatically from ``sample()``, but
171183 may be called manually. Each step method passed should have a
172184 ``competence()`` method that returns an ordinal competence value
173185 corresponding to the variable passed to it. This value quantifies the
174186 appropriateness of the step method for sampling the variable.
175187
188+ The outputs of this function can then be passed to `instantiate_steppers()`
189+ to initialize the assigned step samplers.
190+
176191 Parameters
177192 ----------
178193 model : Model object
@@ -183,24 +198,32 @@ def assign_step_methods(
183198 methods : iterable of step method classes, optional
184199 The set of step methods from which the function may choose. Defaults
185200 to the main step methods provided by PyMC.
186- step_kwargs : dict, optional
187- Parameters for the samplers. Keys are the lower case names of
188- the step method, values a dict of arguments.
189201
190202 Returns
191203 -------
192- methods : list
193- List of step methods associated with the model's variables.
204+ provided_steps: list of Step instances
205+ List of user provided instantiated step(s)
206+ assigned_steps: dict of Step class to Variable
207+ Dictionary with automatically selected step classes as keys and associated value variables as values
194208 """
195- steps : list [Step ] = []
209+ provided_steps : list [Step ] = []
196210 assigned_vars : set [Variable ] = set ()
197211
198212 if step is not None :
199213 if isinstance (step , BlockedStep | CompoundStep ):
200- steps .append (step )
214+ provided_steps = [step ]
215+ elif isinstance (step , Sequence ):
216+ provided_steps = list (step )
201217 else :
202- steps .extend (step )
203- for step in steps :
218+ raise ValueError (f"Step should be a Step or a sequence of Steps, got { step } " )
219+
220+ for step in provided_steps :
221+ if not isinstance (step , BlockedStep | CompoundStep ):
222+ if issubclass (step , BlockedStep | CompoundStep ):
223+ raise ValueError (f"Provided { step } was not initialized" )
224+ else :
225+ raise ValueError (f"{ step } is not a Step instance" )
226+
204227 for var in step .vars :
205228 if var not in model .value_vars :
206229 raise ValueError (
@@ -235,7 +258,7 @@ def assign_step_methods(
235258 )
236259 selected_steps .setdefault (selected , []).append (var )
237260
238- return instantiate_steppers ( model , steps , selected_steps , step_kwargs )
261+ return provided_steps , selected_steps
239262
240263
241264def _print_step_hierarchy (s : Step , level : int = 0 ) -> None :
@@ -719,22 +742,23 @@ def joined_blas_limiter():
719742 msg = f"Only { draws } samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate."
720743 _log .warning (msg )
721744
722- auto_nuts_init = True
723- if step is not None :
724- if isinstance (step , CompoundStep ):
725- for method in step .methods :
726- if isinstance (method , NUTS ):
727- auto_nuts_init = False
728- elif isinstance (step , NUTS ):
729- auto_nuts_init = False
730-
731- initial_points = None
732- step = assign_step_methods (model , step , methods = pm .STEP_METHODS , step_kwargs = kwargs )
745+ provided_steps , selected_steps = assign_step_methods (model , step , methods = pm .STEP_METHODS )
746+ exclusive_nuts = (
747+ # User provided an instantiated NUTS step, and nothing else is needed
748+ (not selected_steps and len (provided_steps ) == 1 and isinstance (provided_steps [0 ], NUTS ))
749+ or
750+ # Only automatically selected NUTS step is needed
751+ (
752+ not provided_steps
753+ and len (selected_steps ) == 1
754+ and issubclass (next (iter (selected_steps )), NUTS )
755+ )
756+ )
733757
734758 if nuts_sampler != "pymc" :
735- if not isinstance ( step , NUTS ) :
759+ if not exclusive_nuts :
736760 raise ValueError (
737- "Model can not be sampled with NUTS alone. Your model is probably not continuous ."
761+ "Model can not be sampled with NUTS alone. It either has discrete variables or a non-differentiable log-probability ."
738762 )
739763
740764 with joined_blas_limiter ():
@@ -755,13 +779,11 @@ def joined_blas_limiter():
755779 ** kwargs ,
756780 )
757781
758- if isinstance (step , list ):
759- step = CompoundStep (step )
760- elif isinstance (step , NUTS ) and auto_nuts_init :
782+ if exclusive_nuts and not provided_steps :
783+ # Special path for NUTS initialization
761784 if "nuts" in kwargs :
762785 nuts_kwargs = kwargs .pop ("nuts" )
763786 [kwargs .setdefault (k , v ) for k , v in nuts_kwargs .items ()]
764- _log .info ("Auto-assigning NUTS sampler..." )
765787 with joined_blas_limiter ():
766788 initial_points , step = init_nuts (
767789 init = init ,
@@ -775,9 +797,8 @@ def joined_blas_limiter():
775797 initvals = initvals ,
776798 ** kwargs ,
777799 )
778-
779- if initial_points is None :
780- # Time to draw/evaluate numeric start points for each chain.
800+ else :
801+ # Get initial points
781802 ipfns = make_initial_point_fns_per_chain (
782803 model = model ,
783804 overrides = initvals ,
@@ -786,11 +807,16 @@ def joined_blas_limiter():
786807 )
787808 initial_points = [ipfn (seed ) for ipfn , seed in zip (ipfns , random_seed_list )]
788809
789- # One final check that shapes and logps at the starting points are okay.
790- ip : dict [str , np .ndarray ]
791- for ip in initial_points :
792- model .check_start_vals (ip )
793- _check_start_shape (model , ip )
810+ # Instantiate automatically selected steps
811+ step = instantiate_steppers (
812+ model ,
813+ steps = provided_steps ,
814+ selected_steps = selected_steps ,
815+ step_kwargs = kwargs ,
816+ initial_point = initial_points [0 ],
817+ )
818+ if isinstance (step , list ):
819+ step = CompoundStep (step )
794820
795821 if var_names is not None :
796822 trace_vars = [v for v in model .unobserved_RVs if v .name in var_names ]
@@ -806,7 +832,7 @@ def joined_blas_limiter():
806832 expected_length = draws + tune ,
807833 step = step ,
808834 trace_vars = trace_vars ,
809- initial_point = ip ,
835+ initial_point = initial_points [ 0 ] ,
810836 model = model ,
811837 )
812838
@@ -954,7 +980,6 @@ def _sample_return(
954980 f"took { t_sampling :.0f} seconds."
955981 )
956982
957- idata = None
958983 if compute_convergence_checks or return_inferencedata :
959984 ikwargs : dict [str , Any ] = {"model" : model , "save_warmup" : not discard_tuned_samples }
960985 ikwargs .update (idata_kwargs )
@@ -1159,7 +1184,6 @@ def _iter_sample(
11591184 diverging : bool
11601185 Indicates if the draw is divergent. Only available with some samplers.
11611186 """
1162- model = modelcontext (model )
11631187 draws = int (draws )
11641188
11651189 if draws < 1 :
@@ -1174,8 +1198,6 @@ def _iter_sample(
11741198 if hasattr (step , "reset_tuning" ):
11751199 step .reset_tuning ()
11761200 for i in range (draws ):
1177- diverging = False
1178-
11791201 if i == 0 and hasattr (step , "iter_count" ):
11801202 step .iter_count = 0
11811203 if i == tune :
@@ -1298,6 +1320,7 @@ def _init_jitter(
12981320 seeds : Sequence [int ] | np .ndarray ,
12991321 jitter : bool ,
13001322 jitter_max_retries : int ,
1323+ logp_dlogp_func = None ,
13011324) -> list [PointType ]:
13021325 """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.
13031326
@@ -1328,19 +1351,30 @@ def _init_jitter(
13281351 if not jitter :
13291352 return [ipfn (seed ) for ipfn , seed in zip (ipfns , seeds )]
13301353
1354+ model_logp_fn : Callable
1355+ if logp_dlogp_func is None :
1356+ model_logp_fn = model .compile_logp ()
1357+ else :
1358+
1359+ def model_logp_fn (ip ):
1360+ q , _ = DictToArrayBijection .map (ip )
1361+ return logp_dlogp_func ([q ], extra_vars = {})[0 ]
1362+
13311363 initial_points = []
13321364 for ipfn , seed in zip (ipfns , seeds ):
1333- rng = np .random .RandomState (seed )
1365+ rng = np .random .default_rng (seed )
13341366 for i in range (jitter_max_retries + 1 ):
13351367 point = ipfn (seed )
1336- if i < jitter_max_retries :
1337- try :
1368+ point_logp = model_logp_fn (point )
1369+ if not np .isfinite (point_logp ):
1370+ if i == jitter_max_retries :
1371+ # Print informative message on last attempted point
13381372 model .check_start_vals (point )
1339- except SamplingError :
1340- # Retry with a new seed
1341- seed = rng . randint ( 2 ** 30 , dtype = np . int64 )
1342- else :
1343- break
1373+ # Retry with a new seed
1374+ seed = rng . integers ( 2 ** 30 , dtype = np . int64 )
1375+ else :
1376+ break
1377+
13441378 initial_points .append (point )
13451379 return initial_points
13461380
@@ -1436,10 +1470,12 @@ def init_nuts(
14361470
14371471 _log .info (f"Initializing NUTS using { init } ..." )
14381472
1439- cb = [
1440- pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "absolute" ),
1441- pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "relative" ),
1442- ]
1473+ cb = []
1474+ if "advi" in init :
1475+ cb = [
1476+ pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "absolute" ),
1477+ pm .callbacks .CheckParametersConvergence (tolerance = 1e-2 , diff = "relative" ),
1478+ ]
14431479
14441480 logp_dlogp_func = model .logp_dlogp_function (ravel_inputs = True )
14451481 logp_dlogp_func .trust_input = True
@@ -1449,6 +1485,7 @@ def init_nuts(
14491485 seeds = random_seed_list ,
14501486 jitter = "jitter" in init ,
14511487 jitter_max_retries = jitter_max_retries ,
1488+ logp_dlogp_func = logp_dlogp_func ,
14521489 )
14531490
14541491 apoints = [DictToArrayBijection .map (point ) for point in initial_points ]
@@ -1562,7 +1599,14 @@ def init_nuts(
15621599 else :
15631600 raise ValueError (f"Unknown initializer: { init } ." )
15641601
1565- step = pm .NUTS (potential = potential , model = model , rng = random_seed_list [0 ], ** kwargs )
1602+ step = pm .NUTS (
1603+ potential = potential ,
1604+ model = model ,
1605+ rng = random_seed_list [0 ],
1606+ initial_point = initial_points [0 ],
1607+ logp_dlogp_func = logp_dlogp_func ,
1608+ ** kwargs ,
1609+ )
15661610
15671611 # Filter deterministics from initial_points
15681612 value_var_names = [var .name for var in model .value_vars ]
0 commit comments