@@ -12,7 +12,6 @@ using ..Reactant:
1212 TracedSetPath,
1313 ConcreteToTraced,
1414 AbstractConcreteArray,
15- XLA,
1615 Sharding,
1716 to_number
1817import .. Reactant: promote_to, make_tracer
@@ -21,15 +20,8 @@ import ..Compiler: donate_argument!
2120"""
2221 process_probprog_function(f, args, op_name)
2322
24- Note: by convention `args` must have the RNG state as the first argument.
25-
26- This function handles the probprog argument convention where:
27- - **Index 1**: RNG state
28- - **Index 2**: Function `f` (when wrapped)
29- - **Index 3+**: Remaining arguments
30-
31- This wrapper ensures the RNG state is threaded through as the first result,
32- followed by the actual function results.
23+ By convention `args` must have the RNG state as the first argument.
24+ Ensures the RNG state is threaded through as the first result, followed by the actual function results.
3325"""
3426function process_probprog_function (f, args, op_name, with_rng= true )
3527 seen = OrderedIdDict ()
@@ -114,22 +106,14 @@ end
114106
115107This function handles the probprog argument convention where:
116108- **Index 1**: RNG state
117- - **Index 2**: Function `f` (when `fnwrap ` is true)
109+ - **Index 2**: Function `f` (when `fnwrapped ` is true)
118110- **Index 3+**: Other arguments
119111
120- When setting results, the function checks:
121- 1. If result path matches `resprefix`, store in `result`
122- 2. If result path matches `argprefix`, store in `args` (adjust indices for wrapped function)
123-
124- `offset` varies depending on the ProbProg operation:
125- - `sample` and `untraced_call` return only function outputs:
126- Use `offset=0`: `linear_results[i]` corresponds to `op.result[i]`
112+ `offset` and `rng_only` vary depending on the ProbProg operation, e.g.:
127113- `simulate` and `generate` return trace, weight, then outputs:
128114 Use `offset=2`: `linear_results[i]` corresponds to `op.result[i+2]`
129- - `mh` and `regenerate` return trace, accepted/weight, rng_state (no model outputs) :
115+ - `mh` and `regenerate` return trace, accepted/weight, new rng_state :
130116 Use `offset=2, rng_only=true`: only process first result (rng_state)
131-
132- `rng_only`: When true, only process the first result (RNG state), skipping model outputs
133117"""
134118function process_probprog_outputs (
135119 op,
0 commit comments