This repository was archived by the owner on Nov 17, 2025. It is now read-only.
File tree Expand file tree Collapse file tree 1 file changed +40
-0
lines changed Expand file tree Collapse file tree 1 file changed +40
-0
lines changed Original file line number Diff line number Diff line change 1+ """Test file created for the sole purpose of tracking the status of Numba compilation"""
2+ import aesara
3+ import aesara .tensor as at
4+ from aeppl import joint_logprob
5+
6+ import aehmc .nuts as nuts
7+
8+
9+ def test_sample_with_numba ():
10+
11+ srng = at .random .RandomStream (seed = 0 )
12+ Y_rv = srng .normal (1 , 2 )
13+
14+ def logprob_fn (y ):
15+ logprob = joint_logprob ({Y_rv : y })
16+ return logprob
17+
18+ # Build the transition kernel
19+ kernel = nuts .new_kernel (srng , logprob_fn )
20+
21+ # Compile a function that updates the chain
22+ y_vv = Y_rv .clone ()
23+ initial_state = nuts .new_state (y_vv , logprob_fn )
24+
25+ step_size = at .as_tensor (1e-2 )
26+ inverse_mass_matrix = at .as_tensor (1.0 )
27+ (
28+ next_state ,
29+ potential_energy ,
30+ potential_energy_grad ,
31+ acceptance_prob ,
32+ num_doublings ,
33+ is_turning ,
34+ is_diverging ,
35+ ), updates = kernel (* initial_state , step_size , inverse_mass_matrix )
36+
37+ next_step_fn = aesara .function ([y_vv ], next_state , updates = updates , mode = "NUMBA" )
38+
39+ # TODO: Assert something
40+ next_step_fn (Y_rv .eval ())
You can’t perform that action at this time.
0 commit comments