Skip to content

Commit b225dc9

Browse files
committed
refactored with callcache
1 parent eace462 commit b225dc9

File tree

6 files changed

+275
-181
lines changed

6 files changed

+275
-181
lines changed

src/probprog/HMC.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,10 @@ function hmc(
1212
initial_momentum=nothing,
1313
) where {Nargs}
1414
args = (rng, args...)
15-
mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "hmc")
16-
17-
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
18-
fnwrap = mlir_fn_res.fnwrapped
19-
func2 = mlir_fn_res.f
20-
21-
inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix)
22-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
23-
24-
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
25-
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
15+
(; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
16+
f, args, "hmc"
17+
)
18+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name)
2619

2720
trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet(
2821
MLIR.IR.context()::MLIR.API.MlirContext
@@ -92,23 +85,32 @@ function hmc(
9285
end
9386

9487
hmc_op = MLIR.Dialects.enzyme.mcmc(
95-
inputs,
88+
mlir_caller_args,
9689
trace_val,
9790
mass_val;
9891
step_size=step_size_val,
9992
num_steps=num_steps_val,
10093
initial_momentum=initial_momentum_val,
10194
new_trace=trace_ty,
10295
accepted=accepted_ty,
103-
output_rng_state=out_tys[1], # by convention
96+
output_rng_state=mlir_result_types[1], # by convention
10497
alg=alg_attr,
10598
fn=fn_attr,
10699
selection=MLIR.IR.Attribute(selection_attr),
107100
)
108101

109102
# (new_trace, accepted, output_rng_state)
110-
process_probprog_outputs(
111-
hmc_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true
103+
traced_result = process_probprog_outputs(
104+
hmc_op,
105+
linear_results,
106+
traced_result,
107+
f,
108+
args,
109+
fnwrapped,
110+
resprefix,
111+
argprefix,
112+
2,
113+
true,
112114
)
113115

114116
new_trace_val = MLIR.IR.result(hmc_op, 1)
@@ -122,5 +124,5 @@ function hmc(
122124
new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ())
123125
accepted = TracedRArray{Bool,0}((), MLIR.IR.result(hmc_op, 2), ())
124126

125-
return new_trace, accepted, result
127+
return new_trace, accepted, traced_result
126128
end

src/probprog/MH.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,10 @@ function mh(
88
selection::Selection,
99
) where {Nargs}
1010
args = (rng, args...)
11-
mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "mh")
12-
13-
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
14-
fnwrap = mlir_fn_res.fnwrapped
15-
func2 = mlir_fn_res.f
16-
17-
inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix)
18-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
19-
20-
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
21-
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
11+
(; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
12+
f, args, "mh"
13+
)
14+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name)
2215

2316
trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet(
2417
MLIR.IR.context()::MLIR.API.MlirContext
@@ -64,18 +57,27 @@ function mh(
6457
accepted_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Bool))
6558

6659
mh_op = MLIR.Dialects.enzyme.mh(
67-
inputs,
60+
mlir_caller_args,
6861
trace_val;
6962
new_trace=trace_ty,
7063
accepted=accepted_ty,
71-
output_rng_state=out_tys[1], # by convention
64+
output_rng_state=mlir_result_types[1], # by convention
7265
fn=fn_attr,
7366
selection=MLIR.IR.Attribute(selection_attr),
7467
)
7568

7669
# Return (new_trace, accepted, output_rng_state)
77-
process_probprog_outputs(
78-
mh_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2, true
70+
traced_result = process_probprog_outputs(
71+
mh_op,
72+
linear_results,
73+
traced_result,
74+
f,
75+
args,
76+
fnwrapped,
77+
resprefix,
78+
argprefix,
79+
2,
80+
true,
7981
)
8082

8183
new_trace_val = MLIR.IR.result(mh_op, 1)
@@ -89,7 +91,7 @@ function mh(
8991
new_trace = TracedRArray{UInt64,0}((), new_trace_ptr, ())
9092
accepted = TracedRArray{Bool,0}((), MLIR.IR.result(mh_op, 2), ())
9193

92-
return new_trace, accepted, result
94+
return new_trace, accepted, traced_result
9395
end
9496

9597
const metropolis_hastings = mh

src/probprog/Modeling.jl

Lines changed: 71 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -11,86 +11,76 @@ function sample(
1111
logpdf::Union{Nothing,Function}=nothing,
1212
) where {Nargs}
1313
args_with_rng = (rng, args...)
14-
mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(
14+
(; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
1515
f, args_with_rng, "sample"
1616
)
1717

18-
(; result, linear_args, linear_results) = mlir_fn_res
19-
fnwrap = mlir_fn_res.fnwrapped
20-
func2 = mlir_fn_res.f
21-
22-
inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix)
23-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
24-
25-
sym = TracedUtils.get_attribute_by_name(func2, "sym_name")
26-
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(sym))
27-
18+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name)
2819
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
2920
symbol_attr = @ccall MLIR.API.mlir_c.enzymeSymbolAttrGet(
3021
MLIR.IR.context()::MLIR.API.MlirContext, symbol_addr::UInt64
3122
)::MLIR.IR.Attribute
3223

33-
# Construct logpdf attribute if `logpdf` function is provided.
3424
logpdf_attr = nothing
3525
if logpdf isa Function
3626
samples = f(args_with_rng...)
3727

38-
# Assume that logpdf parameters follow `(sample, args...)` convention.
28+
# Logpdf calling convention: `(sample, args...)` (no rng state)
3929
logpdf_args = (samples, args...)
4030

41-
logpdf_mlir = TracedUtils.make_mlir_fn(
42-
logpdf,
43-
logpdf_args,
44-
(),
45-
string(logpdf),
46-
false;
47-
do_transpose=false,
48-
args_in_result=:result,
31+
logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(
32+
process_probprog_function(logpdf, logpdf_args, "logpdf", false).f_name
4933
)
50-
51-
logpdf_sym = TracedUtils.get_attribute_by_name(logpdf_mlir.f, "sym_name")
52-
logpdf_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(logpdf_sym))
5334
end
5435

5536
sample_op = MLIR.Dialects.enzyme.sample(
56-
inputs;
57-
outputs=out_tys,
37+
mlir_caller_args;
38+
outputs=mlir_result_types,
5839
fn=fn_attr,
5940
logpdf=logpdf_attr,
6041
symbol=symbol_attr,
6142
name=Base.String(symbol),
6243
)
6344

64-
process_probprog_outputs(
65-
sample_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix
45+
traced_result = process_probprog_outputs(
46+
sample_op,
47+
linear_results,
48+
traced_result,
49+
f,
50+
args_with_rng,
51+
fnwrapped,
52+
resprefix,
53+
argprefix,
6654
)
6755

68-
return result
56+
return traced_result
6957
end
7058

7159
function untraced_call(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
7260
args_with_rng = (rng, args...)
73-
mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(
74-
f, args_with_rng, "call"
75-
)
7661

77-
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
78-
fnwrap = mlir_fn_res.fnwrapped
79-
func2 = mlir_fn_res.f
80-
81-
inputs = process_probprog_inputs(linear_args, f, args_with_rng, fnwrap, argprefix)
82-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
62+
(; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
63+
f, args_with_rng, "untraced_call"
64+
)
8365

84-
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
85-
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
66+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name)
8667

87-
call_op = MLIR.Dialects.enzyme.untracedCall(inputs; outputs=out_tys, fn=fn_attr)
68+
call_op = MLIR.Dialects.enzyme.untracedCall(
69+
mlir_caller_args; outputs=mlir_result_types, fn=fn_attr
70+
)
8871

89-
process_probprog_outputs(
90-
call_op, linear_results, result, f, args_with_rng, fnwrap, resprefix, argprefix
72+
traced_result = process_probprog_outputs(
73+
call_op,
74+
linear_results,
75+
traced_result,
76+
f,
77+
args_with_rng,
78+
fnwrapped,
79+
resprefix,
80+
argprefix,
9181
)
9282

93-
return result
83+
return traced_result
9484
end
9585

9686
# Gen-like helper function.
@@ -110,29 +100,34 @@ end
110100

111101
function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where {Nargs}
112102
args = (rng, args...)
113-
mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "simulate")
114-
115-
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
116-
fnwrap = mlir_fn_res.fnwrapped
117-
func2 = mlir_fn_res.f
118-
119-
inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix)
120-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
121-
122-
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
123-
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
103+
(; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
104+
f, args, "simulate"
105+
)
106+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name)
124107

125108
trace_ty = @ccall MLIR.API.mlir_c.enzymeTraceTypeGet(
126109
MLIR.IR.context()::MLIR.API.MlirContext
127110
)::MLIR.IR.Type
128111
weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))
129112

130113
simulate_op = MLIR.Dialects.enzyme.simulate(
131-
inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fn_attr
114+
mlir_caller_args;
115+
trace=trace_ty,
116+
weight=weight_ty,
117+
outputs=mlir_result_types,
118+
fn=fn_attr,
132119
)
133120

134-
process_probprog_outputs(
135-
simulate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2
121+
traced_result = process_probprog_outputs(
122+
simulate_op,
123+
linear_results,
124+
traced_result,
125+
f,
126+
args,
127+
fnwrapped,
128+
resprefix,
129+
argprefix,
130+
2,
136131
)
137132

138133
trace = MLIR.IR.result(
@@ -146,7 +141,7 @@ function simulate(rng::AbstractRNG, f::Function, args::Vararg{Any,Nargs}) where
146141
trace = TracedRArray{UInt64,0}((), trace, ())
147142
weight = TracedRArray{Float64,0}((), MLIR.IR.result(simulate_op, 2), ())
148143

149-
return trace, weight, result
144+
return trace, weight, traced_result
150145
end
151146

152147
# Gen-like helper function.
@@ -185,17 +180,12 @@ function generate(
185180
constrained_addresses::Set{Address},
186181
) where {Nargs}
187182
args = (rng, args...)
188-
mlir_fn_res, argprefix, resprefix, _ = process_probprog_function(f, args, "generate")
189183

190-
(; result, linear_args, in_tys, linear_results) = mlir_fn_res
191-
fnwrap = mlir_fn_res.fnwrapped
192-
func2 = mlir_fn_res.f
193-
194-
inputs = process_probprog_inputs(linear_args, f, args, fnwrap, argprefix)
195-
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
184+
(; f_name, mlir_caller_args, mlir_result_types, traced_result, linear_results, fnwrapped, argprefix, resprefix) = process_probprog_function(
185+
f, args, "generate"
186+
)
196187

197-
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
198-
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
188+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(f_name)
199189

200190
constraint_ty = @ccall MLIR.API.mlir_c.enzymeConstraintTypeGet(
201191
MLIR.IR.context()::MLIR.API.MlirContext
@@ -229,17 +219,25 @@ function generate(
229219
weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))
230220

231221
generate_op = MLIR.Dialects.enzyme.generate(
232-
inputs,
222+
mlir_caller_args,
233223
constraint_val;
234224
trace=trace_ty,
235225
weight=weight_ty,
236-
outputs=out_tys,
226+
outputs=mlir_result_types,
237227
fn=fn_attr,
238228
constrained_addresses=MLIR.IR.Attribute(constrained_addresses_attr),
239229
)
240230

241-
process_probprog_outputs(
242-
generate_op, linear_results, result, f, args, fnwrap, resprefix, argprefix, 2
231+
traced_result = process_probprog_outputs(
232+
generate_op,
233+
linear_results,
234+
traced_result,
235+
f,
236+
args,
237+
fnwrapped,
238+
resprefix,
239+
argprefix,
240+
2,
243241
)
244242

245243
trace = MLIR.IR.result(
@@ -253,5 +251,5 @@ function generate(
253251
trace = TracedRArray{UInt64,0}((), trace, ())
254252
weight = TracedRArray{Float64,0}((), MLIR.IR.result(generate_op, 2), ())
255253

256-
return trace, weight, result
254+
return trace, weight, traced_result
257255
end

0 commit comments

Comments
 (0)