2626from aesara .link .jax .dispatch import jax_funcify
2727
2828from pymc import Model , modelcontext
29- from pymc .aesaraf import compile_rv_inplace , inputvars
29+ from pymc .aesaraf import compile_rv_inplace
30+ from pymc .backends .arviz import find_observations
31+ from pymc .distributions import logpt
3032from pymc .util import get_default_varnames
3133
3234warnings .warn ("This module is experimental." )
@@ -95,6 +97,39 @@ def logp_fn_wrap(x):
9597 return logp_fn_wrap
9698
9799
100+ # Adopted from arviz numpyro extractor
101+ def _sample_stats_to_xarray (posterior ):
102+ """Extract sample_stats from NumPyro posterior."""
103+ rename_key = {
104+ "potential_energy" : "lp" ,
105+ "adapt_state.step_size" : "step_size" ,
106+ "num_steps" : "n_steps" ,
107+ "accept_prob" : "acceptance_rate" ,
108+ }
109+ data = {}
110+ for stat , value in posterior .get_extra_fields (group_by_chain = True ).items ():
111+ if isinstance (value , (dict , tuple )):
112+ continue
113+ name = rename_key .get (stat , stat )
114+ value = value .copy ()
115+ data [name ] = value
116+ if stat == "num_steps" :
117+ data ["tree_depth" ] = np .log2 (value ).astype (int ) + 1
118+ return data
119+
120+
121+ def _get_log_likelihood (model , samples ):
122+ "Compute log-likelihood for all observations"
123+ data = {}
124+ for v in model .observed_RVs :
125+ logp_v = replace_shared_variables ([logpt (v )])
126+ fgraph = FunctionGraph (model .value_vars , logp_v , clone = False )
127+ jax_fn = jax_funcify (fgraph )
128+ result = jax .vmap (jax .vmap (jax_fn ))(* samples )[0 ]
129+ data [v .name ] = result
130+ return data
131+
132+
98133def sample_numpyro_nuts (
99134 draws = 1000 ,
100135 tune = 1000 ,
@@ -151,9 +186,23 @@ def sample_numpyro_nuts(
151186 map_seed = jax .random .split (seed , chains )
152187
153188 if chains == 1 :
154- pmap_numpyro .run (seed , init_params = init_state , extra_fields = ("num_steps" ,))
189+ init_params = init_state
190+ map_seed = seed
155191 else :
156- pmap_numpyro .run (map_seed , init_params = init_state_batched , extra_fields = ("num_steps" ,))
192+ init_params = init_state_batched
193+
194+ pmap_numpyro .run (
195+ map_seed ,
196+ init_params = init_params ,
197+ extra_fields = (
198+ "num_steps" ,
199+ "potential_energy" ,
200+ "energy" ,
201+ "adapt_state.step_size" ,
202+ "accept_prob" ,
203+ "diverging" ,
204+ ),
205+ )
157206
158207 raw_mcmc_samples = pmap_numpyro .get_samples (group_by_chain = True )
159208
@@ -172,6 +221,11 @@ def sample_numpyro_nuts(
172221 print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
173222
174223 posterior = mcmc_samples
175- az_trace = az .from_dict (posterior = posterior )
224+ az_posterior = az .from_dict (posterior = posterior )
225+
226+ az_obs = az .from_dict (observed_data = find_observations (model ))
227+ az_stats = az .from_dict (sample_stats = _sample_stats_to_xarray (pmap_numpyro ))
228+ az_ll = az .from_dict (log_likelihood = _get_log_likelihood (model , raw_mcmc_samples ))
229+ az_trace = az .concat (az_posterior , az_ll , az_obs , az_stats )
176230
177231 return az_trace
0 commit comments