Skip to content

Commit acb8bf5

Browse files
Make log density function generation on-demand (#416)
This PR introduces on-demand generation of log density functions for JuliaBUGS models. The `skip_source_generation` parameter has been removed from `compile()` and `BUGSModel()`. Instead, log density functions are now generated on-demand when `set_evaluation_mode(model, UseGeneratedLogDensityFunction())` is called. All models start with `UseGraph()` mode. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 5208959 commit acb8bf5

File tree

14 files changed

+222
-277
lines changed

14 files changed

+222
-277
lines changed

JuliaBUGS/benchmark/juliabugs.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ end
7474
# ! writing a _function_ to benchmark all models won't work because of world-age error
7575

7676
function benchmark_JuliaBUGS_model_with_Mooncake(model::JuliaBUGS.BUGSModel)
77-
# p = Base.Fix1(LogDensityProblems.logdensity, model)
77+
# Use generated log density function for Mooncake
78+
model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction())
7879
p = Base.Fix1(model.log_density_computation_function, model.evaluation_env)
7980
backend = AutoMooncake(; config=nothing)
8081
dim = LogDensityProblems.dimension(model)

JuliaBUGS/benchmark/run_benchmarks.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ output_file = get(ENV, "BENCHMARK_OUTPUT", "benchmark_results.md")
2929
open(output_file, "w") do io
3030
println(io, "## Benchmark Results\n")
3131
cpu_info = first(Sys.cpu_info())
32-
println(io, "**Julia $(VERSION)** on $(cpu_info.model)\n")
32+
os_info = Sys.KERNEL
33+
println(io, "**Julia $(VERSION)** on $(cpu_info.model) ($(os_info))\n")
3334
println(io, "Ratio = JuliaBUGS/Stan (lower is better for JuliaBUGS)\n")
3435
println(io, "| Model | Stan Params | JBUGS Params | LD Ratio | Grad Ratio |")
3536
println(io, "|:------|------------:|-------------:|---------:|-----------:|")
@@ -52,6 +53,6 @@ open(output_file, "w") do io
5253
end
5354
println(
5455
io,
55-
"\n*Note: Performance comparison may not be apples-to-apples as parameter counts can differ due to different model parameterizations.*",
56+
"\n*Note: Stan benchmarks use hand-optimized Stan models, not direct BUGS translations. Comparison is illustrative only.*",
5657
)
5758
end

JuliaBUGS/examples/gp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ model = gp_golf_putting(
120120
data.jitter, # Numerical stability term
121121
)
122122

123-
# Optionally, set the evaluation mode. Using generated functions can be faster.
124-
# model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction())
123+
# Generate the log density function for optimal performance
124+
model = JuliaBUGS.set_evaluation_mode(model, JuliaBUGS.UseGeneratedLogDensityFunction())
125125

126126
# --- MCMC Setup with Custom LogDensityProblems Wrapper ---
127127

JuliaBUGS/src/JuliaBUGS.jl

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -239,25 +239,22 @@ function validate_bugs_expression(expr, line_num)
239239
end
240240

241241
"""
242-
compile(model_def, data[, initial_params]; skip_validation=false, skip_source_generation=false)
242+
compile(model_def, data[, initial_params]; skip_validation=false)
243243
244244
Compile the model with model definition and data. Optionally, initializations can be provided.
245245
If initializations are not provided, values will be sampled from the prior distributions.
246246
247247
By default, validates that all functions in the model are in the BUGS allowlist (suitable for @bugs macro).
248248
Set `skip_validation=true` to skip validation (for @model macro usage).
249249
250-
Set `skip_source_generation=true` to bypass generating optimized log-density functions. This produces
251-
models with stable type signatures (no anonymous function types in type parameters), which is essential
252-
for serialization. When enabled, the model will use graph traversal for evaluation (UseGraph mode),
253-
which may be slower but ensures type stability across sessions.
250+
The compiled model uses `UseGraph` evaluation mode by default. To use the optimized generated
251+
log-density function, call `set_evaluation_mode(model, UseGeneratedLogDensityFunction())`.
254252
"""
255253
function compile(
256254
model_def::Expr,
257255
data::NamedTuple,
258256
initial_params::NamedTuple=NamedTuple();
259257
skip_validation::Bool=false,
260-
skip_source_generation::Bool=false,
261258
eval_module::Module=@__MODULE__,
262259
)
263260
# Validate functions by default (for @bugs macro usage)
@@ -287,15 +284,7 @@ function compile(
287284
values(eval_env),
288285
),
289286
)
290-
return BUGSModel(
291-
g,
292-
nonmissing_eval_env,
293-
model_def,
294-
data,
295-
initial_params,
296-
true,
297-
skip_source_generation,
298-
)
287+
return BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params, true)
299288
end
300289
# function compile(
301290
# model_str::String,

JuliaBUGS/src/gibbs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,13 +215,13 @@ julia> using JuliaBUGS: expand_variables, @varname
215215
julia> model_parameters = [@varname(x[1]), @varname(x[2]), @varname(x[3]), @varname(y)];
216216
217217
julia> expand_variables([@varname(x)], model_parameters)
218-
3-element Vector{AbstractPPL.VarName}:
218+
3-element Vector{VarName}:
219219
x[1]
220220
x[2]
221221
x[3]
222222
223223
julia> expand_variables([@varname(x[1]), @varname(y)], model_parameters)
224-
2-element Vector{AbstractPPL.VarName}:
224+
2-element Vector{VarName}:
225225
x[1]
226226
y
227227
```

JuliaBUGS/src/model/abstractppl.jl

Lines changed: 21 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ import AbstractPPL: condition, decondition, evaluate!!
2020
#######################
2121

2222
"""
23-
condition(model::BUGSModel, conditioning_spec; regenerate_log_density::Bool=true)
23+
condition(model::BUGSModel, conditioning_spec)
2424
2525
Create a new model by conditioning on specified variables with given values.
26+
The returned model uses `UseGraph` evaluation mode. To use optimized log density
27+
computation, call `set_evaluation_mode(model, UseGeneratedLogDensityFunction())`.
2628
2729
# Arguments
2830
- `model::BUGSModel`: The model to condition
@@ -45,12 +47,6 @@ New `BUGSModel` with:
4547
4648
# Examples
4749
```jldoctest condition
48-
julia> using JuliaBUGS: @bugs, compile, @varname, initialize!
49-
50-
julia> using JuliaBUGS.Model: condition, parameters
51-
52-
julia> using Test
53-
5450
julia> model_def = @bugs begin
5551
for i in 1:3
5652
x[i] ~ Normal(0, 1)
@@ -69,7 +65,7 @@ julia> model_cond.evaluation_env.x[1:2]
6965
2.0
7066
7167
julia> parameters(model_cond)
72-
2-element Vector{AbstractPPL.VarName}:
68+
2-element Vector{VarName}:
7369
x[3]
7470
y
7571
@@ -86,7 +82,7 @@ julia> model_cond2.evaluation_env.x
8682
7.0
8783
8884
julia> parameters(model_cond2) # All x[i] removed, only y remains
89-
1-element Vector{AbstractPPL.VarName}:
85+
1-element Vector{VarName}:
9086
y
9187
9288
julia> # Check parameter lengths
@@ -105,8 +101,8 @@ julia> # NamedTuple syntax
105101
julia> model_cond3.evaluation_env.y
106102
10.0
107103
108-
julia> parameters(model_cond3) # y removed, only x[i] remain
109-
3-element Vector{AbstractPPL.VarName}:
104+
julia> sort(parameters(model_cond3); by=string) # y removed, only x[i] remain
105+
3-element Vector{VarName}:
110106
x[1]
111107
x[2]
112108
x[3]
@@ -130,12 +126,12 @@ julia> model_cond4.evaluation_env.x[[1, 3]]
130126
3.0
131127
132128
julia> parameters(model_cond4)
133-
2-element Vector{AbstractPPL.VarName}:
129+
2-element Vector{VarName}:
134130
x[2]
135131
y
136132
```
137133
"""
138-
function condition(model::BUGSModel, conditioning_spec; regenerate_log_density::Bool=true)
134+
function condition(model::BUGSModel, conditioning_spec)
139135
# Parse and validate conditioning specification
140136
var_values = _parse_conditioning_spec(conditioning_spec, model)::Dict{<:VarName,<:Any}
141137
vars_to_condition = collect(keys(var_values))::Vector{<:VarName}
@@ -152,12 +148,12 @@ function condition(model::BUGSModel, conditioning_spec; regenerate_log_density::
152148
new_graph = _mark_as_observed(model.g, expanded_vars)
153149

154150
# Create updated model with conditioned variables
151+
# Log density function will be generated when set_evaluation_mode is called with UseGeneratedLogDensityFunction
155152
return _create_modified_model(
156153
model,
157154
new_graph,
158155
new_evaluation_env;
159156
base_model=isnothing(model.base_model) ? model : model.base_model,
160-
regenerate_log_density=regenerate_log_density,
161157
)
162158
end
163159

@@ -256,14 +252,6 @@ For base_model restoration (no args):
256252
257253
# Examples
258254
```jldoctest decondition
259-
julia> using JuliaBUGS: @bugs, compile
260-
261-
julia> using JuliaBUGS.Model: condition, parameters, decondition
262-
263-
julia> using AbstractPPL: @varname
264-
265-
julia> using Test
266-
267255
julia> model_def = @bugs begin
268256
x ~ dnorm(0, 1)
269257
y ~ dnorm(x, 1)
@@ -276,13 +264,13 @@ julia> # Condition model
276264
model_cond = condition(model, (; x = 1.0, y = 1.5));
277265
278266
julia> parameters(model_cond)
279-
AbstractPPL.VarName[]
267+
VarName[]
280268
281269
julia> # Partial deconditioning with specified variables
282270
model_d1 = decondition(model_cond, [@varname(y)]);
283271
284272
julia> parameters(model_d1)
285-
1-element Vector{AbstractPPL.VarName}:
273+
1-element Vector{VarName}:
286274
y
287275
288276
julia> # Full restoration to base model (no arguments)
@@ -351,8 +339,8 @@ julia> # Decondition with subsumption
351339
decondition(model_arr_cond, [@varname(v)])
352340
);
353341
354-
julia> parameters(model_arr_decon)
355-
3-element Vector{AbstractPPL.VarName}:
342+
julia> sort(parameters(model_arr_decon); by=string)
343+
3-element Vector{VarName}:
356344
v[1]
357345
v[2]
358346
v[3]
@@ -552,7 +540,6 @@ function _create_modified_model(
552540
new_graph::BUGSGraph,
553541
new_evaluation_env::NamedTuple;
554542
base_model=nothing,
555-
regenerate_log_density::Bool=true,
556543
)
557544
# Create new graph evaluation data
558545
new_graph_evaluation_data = GraphEvaluationData(new_graph)
@@ -563,42 +550,29 @@ function _create_modified_model(
563550
model, new_parameters
564551
)
565552

566-
# Generate new log density function and update graph evaluation data
567-
new_log_density_computation_function, updated_graph_evaluation_data =
568-
if regenerate_log_density
569-
_regenerate_log_density_function(
570-
model.model_def, new_graph, new_evaluation_env, new_graph_evaluation_data
571-
)
572-
else
573-
# Skip regeneration (fast path): ensure stale code isn't used
574-
nothing, new_graph_evaluation_data
575-
end
576-
577553
# Recompute mutable symbols for the new graph
578-
new_mutable_symbols = get_mutable_symbols(updated_graph_evaluation_data)
554+
new_mutable_symbols = get_mutable_symbols(new_graph_evaluation_data)
579555

580556
# Create the new model with all updated fields
557+
# Log density function is NOT generated here - it will be generated on-demand
558+
# when set_evaluation_mode(model, UseGeneratedLogDensityFunction()) is called
581559
kwargs = Dict{Symbol,Any}(
582560
:untransformed_param_length => new_untransformed_param_length,
583561
:transformed_param_length => new_transformed_param_length,
584562
:evaluation_env => new_evaluation_env,
585-
:graph_evaluation_data => updated_graph_evaluation_data,
563+
:graph_evaluation_data => new_graph_evaluation_data,
586564
:g => new_graph,
587-
:log_density_computation_function => new_log_density_computation_function,
565+
:log_density_computation_function => nothing,
588566
:mutable_symbols => new_mutable_symbols,
567+
:evaluation_mode => UseGraph(),
589568
)
590569

591-
# Force graph evaluation mode when skipping regeneration to avoid stale compiled code
592-
if !regenerate_log_density
593-
kwargs[:evaluation_mode] = UseGraph()
594-
kwargs[:log_density_computation_function] = nothing
595-
end
596-
597570
# Add base_model if provided
598571
if !isnothing(base_model)
599572
kwargs[:base_model] = base_model
600573
end
601574

575+
# Return model without precomputing log density function (on-demand generation)
602576
return BUGSModel(model; kwargs...)
603577
end
604578

0 commit comments

Comments
 (0)