@@ -20,9 +20,11 @@ import AbstractPPL: condition, decondition, evaluate!!
2020# ######################
2121
2222"""
23- condition(model::BUGSModel, conditioning_spec; regenerate_log_density::Bool=true )
23+ condition(model::BUGSModel, conditioning_spec)
2424
2525Create a new model by conditioning on specified variables with given values.
26+ The returned model uses `UseGraph` evaluation mode. To use optimized log density
27+ computation, call `set_evaluation_mode(model, UseGeneratedLogDensityFunction())`.
2628
2729# Arguments
2830- `model::BUGSModel`: The model to condition
@@ -45,12 +47,6 @@ New `BUGSModel` with:
4547
4648# Examples
4749```jldoctest condition
48- julia> using JuliaBUGS: @bugs, compile, @varname, initialize!
49-
50- julia> using JuliaBUGS.Model: condition, parameters
51-
52- julia> using Test
53-
5450julia> model_def = @bugs begin
5551 for i in 1:3
5652 x[i] ~ Normal(0, 1)
@@ -69,7 +65,7 @@ julia> model_cond.evaluation_env.x[1:2]
6965 2.0
7066
7167julia> parameters(model_cond)
72- 2-element Vector{AbstractPPL. VarName}:
68+ 2-element Vector{VarName}:
7369 x[3]
7470 y
7571
@@ -86,7 +82,7 @@ julia> model_cond2.evaluation_env.x
8682 7.0
8783
8884julia> parameters(model_cond2) # All x[i] removed, only y remains
89- 1-element Vector{AbstractPPL. VarName}:
85+ 1-element Vector{VarName}:
9086 y
9187
9288julia> # Check parameter lengths
@@ -105,8 +101,8 @@ julia> # NamedTuple syntax
105101julia> model_cond3.evaluation_env.y
10610210.0
107103
108- julia> parameters(model_cond3) # y removed, only x[i] remain
109- 3-element Vector{AbstractPPL. VarName}:
104+ julia> sort( parameters(model_cond3); by=string ) # y removed, only x[i] remain
105+ 3-element Vector{VarName}:
110106 x[1]
111107 x[2]
112108 x[3]
@@ -130,12 +126,12 @@ julia> model_cond4.evaluation_env.x[[1, 3]]
130126 3.0
131127
132128julia> parameters(model_cond4)
133- 2-element Vector{AbstractPPL. VarName}:
129+ 2-element Vector{VarName}:
134130 x[2]
135131 y
136132```
137133"""
138- function condition (model:: BUGSModel , conditioning_spec; regenerate_log_density :: Bool = true )
134+ function condition (model:: BUGSModel , conditioning_spec)
139135 # Parse and validate conditioning specification
140136 var_values = _parse_conditioning_spec (conditioning_spec, model):: Dict{<:VarName,<:Any}
141137 vars_to_condition = collect (keys (var_values)):: Vector{<:VarName}
@@ -152,12 +148,12 @@ function condition(model::BUGSModel, conditioning_spec; regenerate_log_density::
152148 new_graph = _mark_as_observed (model. g, expanded_vars)
153149
154150 # Create updated model with conditioned variables
151+ # Log density function will be generated when set_evaluation_mode is called with UseGeneratedLogDensityFunction
155152 return _create_modified_model (
156153 model,
157154 new_graph,
158155 new_evaluation_env;
159156 base_model= isnothing (model. base_model) ? model : model. base_model,
160- regenerate_log_density= regenerate_log_density,
161157 )
162158end
163159
@@ -256,14 +252,6 @@ For base_model restoration (no args):
256252
257253# Examples
258254```jldoctest decondition
259- julia> using JuliaBUGS: @bugs, compile
260-
261- julia> using JuliaBUGS.Model: condition, parameters, decondition
262-
263- julia> using AbstractPPL: @varname
264-
265- julia> using Test
266-
267255julia> model_def = @bugs begin
268256 x ~ dnorm(0, 1)
269257 y ~ dnorm(x, 1)
@@ -276,13 +264,13 @@ julia> # Condition model
276264 model_cond = condition(model, (; x = 1.0, y = 1.5));
277265
278266julia> parameters(model_cond)
279- AbstractPPL. VarName[]
267+ VarName[]
280268
281269julia> # Partial deconditioning with specified variables
282270 model_d1 = decondition(model_cond, [@varname(y)]);
283271
284272julia> parameters(model_d1)
285- 1-element Vector{AbstractPPL. VarName}:
273+ 1-element Vector{VarName}:
286274 y
287275
288276julia> # Full restoration to base model (no arguments)
@@ -351,8 +339,8 @@ julia> # Decondition with subsumption
351339 decondition(model_arr_cond, [@varname(v)])
352340 );
353341
354- julia> parameters(model_arr_decon)
355- 3-element Vector{AbstractPPL. VarName}:
342+ julia> sort( parameters(model_arr_decon); by=string )
343+ 3-element Vector{VarName}:
356344 v[1]
357345 v[2]
358346 v[3]
@@ -552,7 +540,6 @@ function _create_modified_model(
552540 new_graph:: BUGSGraph ,
553541 new_evaluation_env:: NamedTuple ;
554542 base_model= nothing ,
555- regenerate_log_density:: Bool = true ,
556543)
557544 # Create new graph evaluation data
558545 new_graph_evaluation_data = GraphEvaluationData (new_graph)
@@ -563,42 +550,29 @@ function _create_modified_model(
563550 model, new_parameters
564551 )
565552
566- # Generate new log density function and update graph evaluation data
567- new_log_density_computation_function, updated_graph_evaluation_data =
568- if regenerate_log_density
569- _regenerate_log_density_function (
570- model. model_def, new_graph, new_evaluation_env, new_graph_evaluation_data
571- )
572- else
573- # Skip regeneration (fast path): ensure stale code isn't used
574- nothing , new_graph_evaluation_data
575- end
576-
577553 # Recompute mutable symbols for the new graph
578- new_mutable_symbols = get_mutable_symbols (updated_graph_evaluation_data )
554+ new_mutable_symbols = get_mutable_symbols (new_graph_evaluation_data )
579555
580556 # Create the new model with all updated fields
557+ # Log density function is NOT generated here - it will be generated on-demand
558+ # when set_evaluation_mode(model, UseGeneratedLogDensityFunction()) is called
581559 kwargs = Dict {Symbol,Any} (
582560 :untransformed_param_length => new_untransformed_param_length,
583561 :transformed_param_length => new_transformed_param_length,
584562 :evaluation_env => new_evaluation_env,
585- :graph_evaluation_data => updated_graph_evaluation_data ,
563+ :graph_evaluation_data => new_graph_evaluation_data ,
586564 :g => new_graph,
587- :log_density_computation_function => new_log_density_computation_function ,
565+ :log_density_computation_function => nothing ,
588566 :mutable_symbols => new_mutable_symbols,
567+ :evaluation_mode => UseGraph (),
589568 )
590569
591- # Force graph evaluation mode when skipping regeneration to avoid stale compiled code
592- if ! regenerate_log_density
593- kwargs[:evaluation_mode ] = UseGraph ()
594- kwargs[:log_density_computation_function ] = nothing
595- end
596-
597570 # Add base_model if provided
598571 if ! isnothing (base_model)
599572 kwargs[:base_model ] = base_model
600573 end
601574
575+ # Return model without precomputing log density function (on-demand generation)
602576 return BUGSModel (model; kwargs... )
603577end
604578
0 commit comments