|
34 | 34 | rng = ReactantRNG(seed) |
35 | 35 | μ = Reactant.ConcreteRNumber(0.0) |
36 | 36 | σ = Reactant.ConcreteRNumber(1.0) |
37 | | - trace, weight = ProbProg.generate_(rng, model, μ, σ, shape) |
| 37 | + trace, weight = ProbProg.generate_(rng, ProbProg.Constraint(), model, μ, σ, shape) |
38 | 38 | @test mean(trace.retval[1]) ≈ 0.0 atol = 0.05 rtol = 0.05 |
39 | 39 | end |
40 | 40 |
|
|
47 | 47 |
|
48 | 48 | constraint = ProbProg.Constraint(:s => (fill(0.1, shape),)) |
49 | 49 |
|
50 | | - trace, weight = ProbProg.generate_(rng, model, μ, σ, shape; constraint) |
| 50 | + trace, weight = ProbProg.generate_(rng, constraint, model, μ, σ, shape) |
51 | 51 |
|
52 | 52 | @test trace.choices[:s][1] == constraint[ProbProg.Address(:s)][1] |
53 | 53 |
|
|
72 | 72 | :u => :y => (fill(0.3, shape),), |
73 | 73 | ) |
74 | 74 |
|
75 | | - trace, weight = ProbProg.generate_(rng, nested_model, μ, σ, shape; constraint) |
| 75 | + trace, weight = ProbProg.generate_(rng, constraint, nested_model, μ, σ, shape) |
76 | 76 |
|
77 | 77 | @test trace.choices[:s][1] == fill(0.1, shape) |
78 | 78 | @test trace.subtraces[:t].choices[:x][1] == fill(0.2, shape) |
@@ -108,33 +108,24 @@ end |
108 | 108 |
|
109 | 109 | constrained_addresses = ProbProg.extract_addresses(constraint1) |
110 | 110 |
|
111 | | - constraint_ptr1 = Reactant.ConcreteRNumber( |
112 | | - reinterpret(UInt64, pointer_from_objref(constraint1)) |
| 111 | + compiled_fn = @compile optimize = :probprog ProbProg.generate( |
| 112 | + rng, constraint1, model, μ, σ, shape; constrained_addresses |
113 | 113 | ) |
114 | 114 |
|
115 | | - wrapper_fn(rng, constraint_ptr, μ, σ) = ProbProg.generate( |
116 | | - rng, model, μ, σ, shape; constraint_ptr, constrained_addresses |
117 | | - ) |
118 | | - |
119 | | - compiled_fn = @compile optimize = :probprog wrapper_fn(rng, constraint_ptr1, μ, σ) |
120 | | - |
121 | 115 | trace1 = nothing |
122 | 116 | seed_buffer = only(rng.seed.data).buffer |
123 | 117 | GC.@preserve seed_buffer constraint1 begin |
124 | | - trace1, _ = compiled_fn(rng, constraint_ptr1, μ, σ) |
125 | | - trace1 = ProbProg.from_trace_tensor(trace1) |
| 118 | + trace1, _ = compiled_fn(rng, constraint1, model, μ, σ, shape) |
| 119 | + trace1 = ProbProg.ProbProgTrace(trace1) |
126 | 120 | end |
127 | 121 |
|
128 | 122 | constraint2 = ProbProg.Constraint(:s => (fill(0.2, shape),)) |
129 | | - constraint_ptr2 = Reactant.ConcreteRNumber( |
130 | | - reinterpret(UInt64, pointer_from_objref(constraint2)) |
131 | | - ) |
132 | 123 |
|
133 | 124 | trace2 = nothing |
134 | 125 | seed_buffer = only(rng.seed.data).buffer |
135 | 126 | GC.@preserve seed_buffer constraint2 begin |
136 | | - trace2, _ = compiled_fn(rng, constraint_ptr2, μ, σ) |
137 | | - trace2 = ProbProg.from_trace_tensor(trace2) |
| 127 | + trace2, _ = compiled_fn(rng, constraint2, model, μ, σ, shape) |
| 128 | + trace2 = ProbProg.ProbProgTrace(trace2) |
138 | 129 | end |
139 | 130 |
|
140 | 131 | @test trace1.choices[:s][1] != trace2.choices[:s][1] |
|
0 commit comments