We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c61e9cd commit 5352798Copy full SHA for 5352798
pymc/sampling/mcmc.py
@@ -305,7 +305,14 @@ def _sample_external_nuts(
305
"`var_names` are currently ignored by the nutpie sampler",
306
UserWarning,
307
)
308
- compiled_model = nutpie.compile_pymc_model(model)
+ compile_kwargs = {}
309
+ for kwarg in ("backend", "gradient_backend"):
310
+ if kwarg in nuts_sampler_kwargs:
311
+ compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
312
+ compiled_model = nutpie.compile_pymc_model(
313
+ model,
314
+ **compile_kwargs,
315
+ )
316
t_start = time.time()
317
idata = nutpie.sample(
318
compiled_model,
0 commit comments