@@ -124,8 +124,9 @@ def _get_log_likelihood(model, samples):
124124 for v in model .observed_RVs :
125125 logp_v = replace_shared_variables ([logpt (v )])
126126 fgraph = FunctionGraph (model .value_vars , logp_v , clone = False )
127+ optimize_graph (fgraph , include = ["fast_run" ], exclude = ["cxx_only" , "BlasOpt" ])
127128 jax_fn = jax_funcify (fgraph )
128- result = jax .vmap (jax .vmap (jax_fn ))(* samples )[0 ]
129+ result = jax .jit ( jax . vmap (jax .vmap (jax_fn ) ))(* samples )[0 ]
129130 data [v .name ] = result
130131 return data
131132
@@ -150,6 +151,20 @@ def sample_numpyro_nuts(
150151
151152 vars_to_sample = list (get_default_varnames (var_names , include_transformed = keep_untransformed ))
152153
154+ coords = {
155+ cname : np .array (cvals ) if isinstance (cvals , tuple ) else cvals
156+ for cname , cvals in model .coords .items ()
157+ if cvals is not None
158+ }
159+
160+ if hasattr (model , "RV_dims" ):
161+ dims = {
162+ var_name : [dim for dim in dims if dim is not None ]
163+ for var_name , dims in model .RV_dims .items ()
164+ }
165+ else :
166+ dims = {}
167+
153168 tic1 = pd .Timestamp .now ()
154169 print ("Compiling..." , file = sys .stdout )
155170
@@ -213,6 +228,7 @@ def sample_numpyro_nuts(
213228 mcmc_samples = {}
214229 for v in vars_to_sample :
215230 fgraph = FunctionGraph (model .value_vars , [v ], clone = False )
231+ optimize_graph (fgraph , include = ["fast_run" ], exclude = ["cxx_only" , "BlasOpt" ])
216232 jax_fn = jax_funcify (fgraph )
217233 result = jax .vmap (jax .vmap (jax_fn ))(* raw_mcmc_samples )[0 ]
218234 mcmc_samples [v .name ] = result
@@ -221,11 +237,13 @@ def sample_numpyro_nuts(
221237 print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
222238
223239 posterior = mcmc_samples
224- az_posterior = az .from_dict (posterior = posterior )
225-
226- az_obs = az .from_dict (observed_data = find_observations (model ))
227- az_stats = az .from_dict (sample_stats = _sample_stats_to_xarray (pmap_numpyro ))
228- az_ll = az .from_dict (log_likelihood = _get_log_likelihood (model , raw_mcmc_samples ))
229- az_trace = az .concat (az_posterior , az_ll , az_obs , az_stats )
240+ az_trace = az .from_dict (
241+ posterior = posterior ,
242+ log_likelihood = _get_log_likelihood (model , raw_mcmc_samples ),
243+ observed_data = find_observations (model ),
244+ sample_stats = _sample_stats_to_xarray (pmap_numpyro ),
245+ coords = coords ,
246+ dims = dims ,
247+ )
230248
231249 return az_trace
0 commit comments