Skip to content

Commit 289b9d2

Browse files
committed
CI
1 parent 71c6fba commit 289b9d2

File tree

5 files changed

+9
-27
lines changed

5 files changed

+9
-27
lines changed

src/probprog/HMC.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ..Reactant: ConcreteRNumber, TracedRArray
1+
using ..Reactant: TracedRArray
22

33
function hmc(
44
rng::AbstractRNG,

src/probprog/MH.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using ..Reactant: ConcreteRNumber, TracedRArray
1+
using ..Reactant: TracedRArray
22

33
function mh(
44
rng::AbstractRNG,

src/probprog/Modeling.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray, ConcreteRNumber
2-
using ..Compiler: @jit, @compile
1+
using ..Reactant: MLIR, TracedUtils, AbstractRNG, TracedRArray
2+
using ..Compiler: @compile
33

44
include("Utils.jl")
55

src/probprog/ProbProg.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ export ProbProgTrace, Constraint, Selection, Address
1616

1717
# Utility functions.
1818
export get_choices, select
19-
export to_trace_tensor, from_trace_tensor
20-
export to_constraint_tensor, from_constraint_tensor
2119

2220
# Core MLIR ops.
2321
export sample, untraced_call, simulate, generate, mh, hmc

src/probprog/Utils.jl

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ using ..Reactant:
1212
TracedSetPath,
1313
ConcreteToTraced,
1414
AbstractConcreteArray,
15-
XLA,
1615
Sharding,
1716
to_number
1817
import ..Reactant: promote_to, make_tracer
@@ -21,15 +20,8 @@ import ..Compiler: donate_argument!
2120
"""
2221
process_probprog_function(f, args, op_name)
2322
24-
Note: by convention `args` must have the RNG state as the first argument.
25-
26-
This function handles the probprog argument convention where:
27-
- **Index 1**: RNG state
28-
- **Index 2**: Function `f` (when wrapped)
29-
- **Index 3+**: Remaining arguments
30-
31-
This wrapper ensures the RNG state is threaded through as the first result,
32-
followed by the actual function results.
23+
By convention `args` must have the RNG state as the first argument.
24+
Ensures the RNG state is threaded through as the first result, followed by the actual function results.
3325
"""
3426
function process_probprog_function(f, args, op_name, with_rng=true)
3527
seen = OrderedIdDict()
@@ -114,22 +106,14 @@ end
114106
115107
This function handles the probprog argument convention where:
116108
- **Index 1**: RNG state
117-
- **Index 2**: Function `f` (when `fnwrap` is true)
109+
- **Index 2**: Function `f` (when `fnwrapped` is true)
118110
- **Index 3+**: Other arguments
119111
120-
When setting results, the function checks:
121-
1. If result path matches `resprefix`, store in `result`
122-
2. If result path matches `argprefix`, store in `args` (adjust indices for wrapped function)
123-
124-
`offset` varies depending on the ProbProg operation:
125-
- `sample` and `untraced_call` return only function outputs:
126-
Use `offset=0`: `linear_results[i]` corresponds to `op.result[i]`
112+
`offset` and `rng_only` vary depending on the ProbProg operation, e.g.:
127113
- `simulate` and `generate` return trace, weight, then outputs:
128114
Use `offset=2`: `linear_results[i]` corresponds to `op.result[i+2]`
129-
- `mh` and `regenerate` return trace, accepted/weight, rng_state (no model outputs):
115+
- `mh` and `regenerate` return trace, accepted/weight, new rng_state:
130116
Use `offset=2, rng_only=true`: only process first result (rng_state)
131-
132-
`rng_only`: When true, only process the first result (RNG state), skipping model outputs
133117
"""
134118
function process_probprog_outputs(
135119
op,

0 commit comments

Comments
 (0)