5050 find_observations ,
5151)
5252from pymc .backends .base import IBaseTrace , MultiTrace , _choose_chains
53+ from pymc .backends .zarr import ZarrTrace
5354from pymc .blocking import DictToArrayBijection
5455from pymc .exceptions import SamplingError
5556from pymc .initial_point import PointType , StartDict , make_initial_point_fns_per_chain
@@ -503,7 +504,7 @@ def sample(
503504 model : Model | None = None ,
504505 compile_kwargs : dict | None = None ,
505506 ** kwargs ,
506- ) -> InferenceData | MultiTrace :
507+ ) -> InferenceData | MultiTrace | ZarrTrace :
507508 r"""Draw samples from the posterior using the given step methods.
508509
509510 Multiple step methods are supported via compound step methods.
@@ -570,7 +571,13 @@ def sample(
570571 Number of iterations of initializer. Only works for 'ADVI' init methods.
571572 trace : backend, optional
572573 A backend instance or None.
573- If None, the NDArray backend is used.
574+ If ``None``, a ``MultiTrace`` object with underlying ``NDArray`` trace objects
575+ is used. If ``trace`` is a :class:`~pymc.backends.zarr.ZarrTrace` instance,
576+ the drawn samples will be written onto the desired storage while sampling is
577+ on-going. This means sampling runs that, for whatever reason, die in the middle
578+ of their execution will write the partial results onto the storage. If the
579+ storage persist on disk, these results should be available even after a server
580+ crash. See :class:`~pymc.backends.zarr.ZarrTrace` for more information.
574581 discard_tuned_samples : bool
575582 Whether to discard posterior samples of the tune interval.
576583 compute_convergence_checks : bool, default=True
@@ -607,8 +614,12 @@ def sample(
607614
608615 Returns
609616 -------
610- trace : pymc.backends.base.MultiTrace or arviz.InferenceData
611- A ``MultiTrace`` or ArviZ ``InferenceData`` object that contains the samples.
617+ trace : pymc.backends.base.MultiTrace | pymc.backends.zarr.ZarrTrace | arviz.InferenceData
618+ A ``MultiTrace``, :class:`~arviz.InferenceData` or
619+ :class:`~pymc.backends.zarr.ZarrTrace` object that contains the samples. A
620+ ``ZarrTrace`` is only returned if the supplied ``trace`` argument is a
621+ ``ZarrTrace`` instance. Refer to :class:`~pymc.backends.zarr.ZarrTrace` for
622+ the benefits this backend provides.
612623
613624 Notes
614625 -----
@@ -741,7 +752,7 @@ def joined_blas_limiter():
741752 rngs = get_random_generator (random_seed ).spawn (chains )
742753 random_seed_list = [rng .integers (2 ** 30 ) for rng in rngs ]
743754
744- if not discard_tuned_samples and not return_inferencedata :
755+ if not discard_tuned_samples and not return_inferencedata and not isinstance ( trace , ZarrTrace ) :
745756 warnings .warn (
746757 "Tuning samples will be included in the returned `MultiTrace` object, which can lead to"
747758 " complications in your downstream analysis. Please consider to switch to `InferenceData`:\n "
@@ -852,6 +863,7 @@ def joined_blas_limiter():
852863 trace_vars = trace_vars ,
853864 initial_point = initial_points [0 ],
854865 model = model ,
866+ tune = tune ,
855867 )
856868
857869 sample_args = {
@@ -934,7 +946,7 @@ def joined_blas_limiter():
934946 # into a function to make it easier to test and refactor.
935947 return _sample_return (
936948 run = run ,
937- traces = traces ,
949+ traces = trace if isinstance ( trace , ZarrTrace ) else traces ,
938950 tune = tune ,
939951 t_sampling = t_sampling ,
940952 discard_tuned_samples = discard_tuned_samples ,
@@ -949,7 +961,7 @@ def joined_blas_limiter():
949961def _sample_return (
950962 * ,
951963 run : RunType | None ,
952- traces : Sequence [IBaseTrace ],
964+ traces : Sequence [IBaseTrace ] | ZarrTrace ,
953965 tune : int ,
954966 t_sampling : float ,
955967 discard_tuned_samples : bool ,
@@ -958,18 +970,69 @@ def _sample_return(
958970 keep_warning_stat : bool ,
959971 idata_kwargs : dict [str , Any ],
960972 model : Model ,
961- ) -> InferenceData | MultiTrace :
973+ ) -> InferenceData | MultiTrace | ZarrTrace :
962974 """Pick/slice chains, run diagnostics and convert to the desired return type.
963975
964976 Final step of `pm.sampler`.
965977 """
978+ if isinstance (traces , ZarrTrace ):
979+ # Split warmup from posterior samples
980+ traces .split_warmup_groups ()
981+
982+ # Set sampling time
983+ traces .sampling_time = t_sampling
984+
985+ # Compute number of actual draws per chain
986+ total_draws_per_chain = traces ._sampling_state .draw_idx [:]
987+ n_chains = len (traces .straces )
988+ desired_tune = traces .tuning_steps
989+ desired_draw = len (traces .posterior .draw )
990+ tuning_steps_per_chain = np .clip (total_draws_per_chain , 0 , desired_tune )
991+ draws_per_chain = total_draws_per_chain - tuning_steps_per_chain
992+
993+ total_n_tune = tuning_steps_per_chain .sum ()
994+ total_draws = draws_per_chain .sum ()
995+
996+ _log .info (
997+ f'Sampling { n_chains } chain{ "s" if n_chains > 1 else "" } for { desired_tune :_d} desired tune and { desired_draw :_d} desired draw iterations '
998+ f"(Actually sampled { total_n_tune :_d} tune and { total_draws :_d} draws total) "
999+ f"took { t_sampling :.0f} seconds."
1000+ )
1001+
1002+ if compute_convergence_checks or return_inferencedata :
1003+ idata = traces .to_inferencedata (save_warmup = not discard_tuned_samples )
1004+ log_likelihood = idata_kwargs .pop ("log_likelihood" , False )
1005+ if log_likelihood :
1006+ from pymc .stats .log_density import compute_log_likelihood
1007+
1008+ idata = compute_log_likelihood (
1009+ idata ,
1010+ var_names = None if log_likelihood is True else log_likelihood ,
1011+ extend_inferencedata = True ,
1012+ model = model ,
1013+ sample_dims = ["chain" , "draw" ],
1014+ progressbar = False ,
1015+ )
1016+ if compute_convergence_checks :
1017+ warns = run_convergence_checks (idata , model )
1018+ for warn in warns :
1019+ traces ._sampling_state .global_warnings .append (np .array ([warn ]))
1020+ log_warnings (warns )
1021+
1022+ if return_inferencedata :
1023+ # By default we drop the "warning" stat which contains `SamplerWarning`
1024+ # objects that can not be stored with `.to_netcdf()`.
1025+ if not keep_warning_stat :
1026+ return drop_warning_stat (idata )
1027+ return idata
1028+ return traces
1029+
9661030 # Pick and slice chains to keep the maximum number of samples
9671031 if discard_tuned_samples :
9681032 traces , length = _choose_chains (traces , tune )
9691033 else :
9701034 traces , length = _choose_chains (traces , 0 )
9711035 mtrace = MultiTrace (traces )[:length ]
972-
9731036 # count the number of tune/draw iterations that happened
9741037 # ideally via the "tune" statistic, but not all samplers record it!
9751038 if "tune" in mtrace .stat_names :
0 commit comments