@@ -11,86 +11,76 @@ function sample(
1111 logpdf:: Union{Nothing,Function} = nothing ,
1212) where {Nargs}
1313 args_with_rng = (rng, args... )
14- mlir_fn_res, argprefix, resprefix, _ = process_probprog_function (
14+ (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function (
1515 f, args_with_rng, " sample"
1616 )
1717
18- (; result, linear_args, linear_results) = mlir_fn_res
19- fnwrap = mlir_fn_res. fnwrapped
20- func2 = mlir_fn_res. f
21-
22- inputs = process_probprog_inputs (linear_args, f, args_with_rng, fnwrap, argprefix)
23- out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
24-
25- sym = TracedUtils. get_attribute_by_name (func2, " sym_name" )
26- fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (sym))
27-
18+ fn_attr = MLIR. IR. FlatSymbolRefAttribute (f_name)
2819 symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
2920 symbol_attr = @ccall MLIR. API. mlir_c. enzymeSymbolAttrGet (
3021 MLIR. IR. context ():: MLIR.API.MlirContext , symbol_addr:: UInt64
3122 ):: MLIR.IR.Attribute
3223
33- # Construct logpdf attribute if `logpdf` function is provided.
3424 logpdf_attr = nothing
3525 if logpdf isa Function
3626 samples = f (args_with_rng... )
3727
38- # Assume that logpdf parameters follow `(sample, args...)` convention.
28+ # Logpdf calling convention: `(sample, args...)` (no rng state)
3929 logpdf_args = (samples, args... )
4030
41- logpdf_mlir = TracedUtils. make_mlir_fn (
42- logpdf,
43- logpdf_args,
44- (),
45- string (logpdf),
46- false ;
47- do_transpose= false ,
48- args_in_result= :result ,
31+ logpdf_attr = MLIR. IR. FlatSymbolRefAttribute (
32+ process_probprog_function (logpdf, logpdf_args, " logpdf" , false ). f_name
4933 )
50-
51- logpdf_sym = TracedUtils. get_attribute_by_name (logpdf_mlir. f, " sym_name" )
52- logpdf_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (logpdf_sym))
5334 end
5435
5536 sample_op = MLIR. Dialects. enzyme. sample (
56- inputs ;
57- outputs= out_tys ,
37+ mlir_caller_args ;
38+ outputs= mlir_result_types ,
5839 fn= fn_attr,
5940 logpdf= logpdf_attr,
6041 symbol= symbol_attr,
6142 name= Base. String (symbol),
6243 )
6344
64- process_probprog_outputs (
65- sample_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix
45+ traced_result = process_probprog_outputs (
46+ sample_op,
47+ linear_results,
48+ traced_result,
49+ f,
50+ args_with_rng,
51+ fnwrapped,
52+ resprefix,
53+ argprefix,
6654 )
6755
68- return result
56+ return traced_result
6957end
7058
7159function untraced_call (rng:: AbstractRNG , f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
7260 args_with_rng = (rng, args... )
73- mlir_fn_res, argprefix, resprefix, _ = process_probprog_function (
74- f, args_with_rng, " call"
75- )
7661
77- (; result, linear_args, in_tys, linear_results) = mlir_fn_res
78- fnwrap = mlir_fn_res. fnwrapped
79- func2 = mlir_fn_res. f
80-
81- inputs = process_probprog_inputs (linear_args, f, args_with_rng, fnwrap, argprefix)
82- out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
62+ (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function (
63+ f, args_with_rng, " untraced_call"
64+ )
8365
84- fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
85- fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
66+ fn_attr = MLIR. IR. FlatSymbolRefAttribute (f_name)
8667
87- call_op = MLIR. Dialects. enzyme. untracedCall (inputs; outputs= out_tys, fn= fn_attr)
68+ call_op = MLIR. Dialects. enzyme. untracedCall (
69+ mlir_caller_args; outputs= mlir_result_types, fn= fn_attr
70+ )
8871
89- process_probprog_outputs (
90- call_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix
72+ traced_result = process_probprog_outputs (
73+ call_op,
74+ linear_results,
75+ traced_result,
76+ f,
77+ args_with_rng,
78+ fnwrapped,
79+ resprefix,
80+ argprefix,
9181 )
9282
93- return result
83+ return traced_result
9484end
9585
9686# Gen-like helper function.
@@ -110,29 +100,34 @@ end
110100
111101function simulate (rng:: AbstractRNG , f:: Function , args:: Vararg{Any,Nargs} ) where {Nargs}
112102 args = (rng, args... )
113- mlir_fn_res, argprefix, resprefix, _ = process_probprog_function (f, args, " simulate" )
114-
115- (; result, linear_args, in_tys, linear_results) = mlir_fn_res
116- fnwrap = mlir_fn_res. fnwrapped
117- func2 = mlir_fn_res. f
118-
119- inputs = process_probprog_inputs (linear_args, f, args, fnwrap, argprefix)
120- out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
121-
122- fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
123- fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
103+ (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function (
104+ f, args, " simulate"
105+ )
106+ fn_attr = MLIR. IR. FlatSymbolRefAttribute (f_name)
124107
125108 trace_ty = @ccall MLIR. API. mlir_c. enzymeTraceTypeGet (
126109 MLIR. IR. context ():: MLIR.API.MlirContext
127110 ):: MLIR.IR.Type
128111 weight_ty = MLIR. IR. TensorType (Int64[], MLIR. IR. Type (Float64))
129112
130113 simulate_op = MLIR. Dialects. enzyme. simulate (
131- inputs; trace= trace_ty, weight= weight_ty, outputs= out_tys, fn= fn_attr
114+ mlir_caller_args;
115+ trace= trace_ty,
116+ weight= weight_ty,
117+ outputs= mlir_result_types,
118+ fn= fn_attr,
132119 )
133120
134- process_probprog_outputs (
135- simulate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2
121+ traced_result = process_probprog_outputs (
122+ simulate_op,
123+ linear_results,
124+ traced_result,
125+ f,
126+ args,
127+ fnwrapped,
128+ resprefix,
129+ argprefix,
130+ 2 ,
136131 )
137132
138133 trace = MLIR. IR. result (
@@ -146,7 +141,7 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where
146141 trace = TracedRArray {UInt64,0} ((), trace, ())
147142 weight = TracedRArray {Float64,0} ((), MLIR. IR. result (simulate_op, 2 ), ())
148143
149- return trace, weight, result
144+ return trace, weight, traced_result
150145end
151146
152147# Gen-like helper function.
@@ -185,17 +180,12 @@ function generate(
185180 constrained_addresses:: Set{Address} ,
186181) where {Nargs}
187182 args = (rng, args... )
188- mlir_fn_res, argprefix, resprefix, _ = process_probprog_function (f, args, " generate" )
189183
190- (; result, linear_args, in_tys, linear_results) = mlir_fn_res
191- fnwrap = mlir_fn_res. fnwrapped
192- func2 = mlir_fn_res. f
193-
194- inputs = process_probprog_inputs (linear_args, f, args, fnwrap, argprefix)
195- out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
184+ (; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function (
185+ f, args, " generate"
186+ )
196187
197- fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
198- fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
188+ fn_attr = MLIR. IR. FlatSymbolRefAttribute (f_name)
199189
200190 constraint_ty = @ccall MLIR. API. mlir_c. enzymeConstraintTypeGet (
201191 MLIR. IR. context ():: MLIR.API.MlirContext
@@ -229,17 +219,25 @@ function generate(
229219 weight_ty = MLIR. IR. TensorType (Int64[], MLIR. IR. Type (Float64))
230220
231221 generate_op = MLIR. Dialects. enzyme. generate (
232- inputs ,
222+ mlir_caller_args ,
233223 constraint_val;
234224 trace= trace_ty,
235225 weight= weight_ty,
236- outputs= out_tys ,
226+ outputs= mlir_result_types ,
237227 fn= fn_attr,
238228 constrained_addresses= MLIR. IR. Attribute (constrained_addresses_attr),
239229 )
240230
241- process_probprog_outputs (
242- generate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2
231+ traced_result = process_probprog_outputs (
232+ generate_op,
233+ linear_results,
234+ traced_result,
235+ f,
236+ args,
237+ fnwrapped,
238+ resprefix,
239+ argprefix,
240+ 2 ,
243241 )
244242
245243 trace = MLIR. IR. result (
@@ -253,5 +251,5 @@ function generate(
253251 trace = TracedRArray {UInt64,0} ((), trace, ())
254252 weight = TracedRArray {Float64,0} ((), MLIR. IR. result (generate_op, 2 ), ())
255253
256- return trace, weight, result
254+ return trace, weight, traced_result
257255end
0 commit comments