@@ -26,9 +26,8 @@ function TracedModel(
2626 " Sampling with `$(sampler. alg) ` does not support models with keyword arguments. See issue #2007 for more details." ,
2727 )
2828 end
29- return TracedModel {AbstractSampler,AbstractVarInfo,Model,Tuple} (
30- spl_model, sampler, varinfo, (spl_model. f, args... )
31- )
29+ evaluator = (spl_model. f, args... )
30+ return TracedModel (spl_model, sampler, varinfo, evaluator)
3231end
3332
3433function AdvancedPS. advance! (
@@ -60,20 +59,10 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
6059 return Accessors. @set trace. model. varinfo = DynamicPPL. resetlogp!! (trace. model. varinfo)
6160end
6261
63- function AdvancedPS. update_rng! (
64- trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
65- )
66- # Extract the `args`.
67- args = trace. model. ctask. args
68- # From `args`, extract the `SamplingContext`, which contains the RNG.
69- sampling_context = args[3 ]
70- rng = sampling_context. rng
71- trace. rng = rng
72- return trace
73- end
74-
75- function Libtask. TapedTask (model:: TracedModel , :: Random.AbstractRNG , args... ; kwargs... ) # RNG ?
76- return Libtask. TapedTask (model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs... )
62+ function Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
63+ return Libtask. TapedTask (
64+ taped_globals, model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs...
65+ )
7766end
7867
7968abstract type ParticleInference <: InferenceAlgorithm end
@@ -403,11 +392,11 @@ end
403392
404393function trace_local_varinfo_maybe (varinfo)
405394 try
406- trace = AdvancedPS . current_trace ()
407- return trace. model. f. varinfo
395+ trace = Libtask . get_taped_globals (Any) . other
396+ return ( trace === nothing ? varinfo : trace . model. f. varinfo) :: AbstractVarInfo
408397 catch e
409398 # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
410- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
399+ if e == KeyError (:task_variable )
411400 return varinfo
412401 else
413402 rethrow (e)
@@ -417,11 +406,10 @@ end
417406
418407function trace_local_rng_maybe (rng:: Random.AbstractRNG )
419408 try
420- trace = AdvancedPS. current_trace ()
421- return trace. rng
409+ return Libtask. get_taped_globals (Any). rng
422410 catch e
423411 # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
424- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
412+ if e == KeyError (:task_variable )
425413 return rng
426414 else
427415 rethrow (e)
@@ -487,6 +475,25 @@ function AdvancedPS.Trace(
487475
488476 tmodel = TracedModel (model, sampler, newvarinfo, rng)
489477 newtrace = AdvancedPS. Trace (tmodel, rng)
490- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
491478 return newtrace
492479end
480+
481+ # We need to tell Libtask which calls may have `produce` calls within them. In practice most
482+ # of these won't be needed, because of inlining and the fact that `might_produce` is only
483+ # called on `:invoke` expressions rather than `:call`s, but since those are implementation
484+ # details of the compiler, we set a bunch of methods as might_produce = true. We start with
485+ # `acclogp_observe!!` which is what calls `produce` and go up the call stack.
486+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}} ) = true
487+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} ) = true
488+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}} ) = true
489+ function Libtask. might_produce (
490+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
491+ )
492+ return true
493+ end
494+ function Libtask. might_produce (
495+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
496+ )
497+ return true
498+ end
499+ Libtask. might_produce (:: Type{<:Tuple{<:DynamicPPL.Model,Vararg}} ) = true
0 commit comments