@@ -794,15 +794,26 @@ julia> # Now `a.x` will be sampled.
794794fixed (model:: Model ) = fixed (model. context)
795795
796796"""
797- (model::Model)([rng, varinfo, sampler, context])
797+ (model::Model)()
798+ (model::Model)(rng[, varinfo, sampler, context])
798799
799- Sample from the `model` using the `sampler` with random number generator `rng` and the
800- `context`, and store the sample and log joint probability in `varinfo`.
800+ Sample from the `model` using the `sampler` with random number generator `rng`
801+ and the `context`, and store the sample and log joint probability in `varinfo`.
801802
802- The method resets the log joint probability of `varinfo` and increases the evaluation
803- number of `sampler`.
803+ Returns the model's return value.
804+
805+ If no arguments are provided, uses the default random number generator and
806+ samples from the prior.
804807"""
805- (model:: Model )(args... ) = first (evaluate!! (model, args... ))
808+ (model:: Model )() = model (Random. default_rng ())
809+ function (model:: Model )(
810+ rng:: AbstractRNG ,
811+ varinfo:: AbstractVarInfo = VarInfo (),
812+ sampler:: AbstractSampler = SampleFromPrior (),
813+ )
814+ spl_ctx = SamplingContext (rng, sampler, DefaultContext ())
815+ return evaluate!! (model, varinfo, spl_ctx)
816+ end
806817
807818"""
808819 use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
@@ -815,65 +826,51 @@ function use_threadsafe_eval(context::AbstractContext, varinfo::AbstractVarInfo)
815826end
816827
817828"""
818- evaluate!!(model::Model[, rng, varinfo, sampler, context])
819-
820- Sample from the `model` using the `sampler` with random number generator `rng` and the
821- `context`, and store the sample and log joint probability in `varinfo`.
829+ sample!!([rng::Random.AbstractRNG, ]model::Model, varinfo)
822830
823- Returns both the return-value of the original model, and the resulting varinfo.
831+ Evaluate the `model` with the given `varinfo`, but perform sampling during the
832+ evaluation by wrapping the model's context in a `SamplingContext`.
824833
825- The method resets the log joint probability of `varinfo` and increases the evaluation
826- number of `sampler`.
834+ Returns a tuple of the model's return value, plus the updated `varinfo` object.
827835"""
828- function AbstractPPL. evaluate!! (
829- model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext
830- )
831- return if use_threadsafe_eval (context, varinfo)
832- evaluate_threadsafe!! (model, varinfo, context)
833- else
834- evaluate_threadunsafe!! (model, varinfo, context)
835- end
836+ function sample!! (rng:: AbstractRNG , model:: Model , varinfo:: AbstractVarInfo )
837+ sampling_model = contextualize (
838+ model, SamplingContext (rng, SampleFromPrior (), model. context)
839+ )
840+ return evaluate!! (sampling_model, varinfo)
836841end
837842
838- function AbstractPPL. evaluate!! (
839- model:: Model ,
840- rng:: Random.AbstractRNG ,
841- varinfo:: AbstractVarInfo = VarInfo (),
842- sampler:: AbstractSampler = SampleFromPrior (),
843- context:: AbstractContext = DefaultContext (),
844- )
845- return evaluate!! (model, varinfo, SamplingContext (rng, sampler, context))
846- end
843+ """
844+ evaluate!!(model::Model, varinfo)
845+ evaluate!!(model::Model, varinfo, context)
847846
848- function AbstractPPL . evaluate!! ( model:: Model , context:: AbstractContext )
849- return evaluate!! (model, VarInfo (), context)
850- end
847+ Evaluate the ` model` with the given `varinfo`. If an extra context stack is
848+ provided, the model's context is inserted into that context stack. See
849+ [`combine_model_and_external_contexts`](@ref).
851850
852- function AbstractPPL. evaluate!! (
853- model:: Model , args:: Union{AbstractVarInfo,AbstractSampler,AbstractContext} ...
854- )
855- return evaluate!! (model, Random. default_rng (), args... )
856- end
851+ If multiple threads are available, the varinfo provided will be wrapped in a
852+ [`DynamicPPL.ThreadSafeVarInfo`](@ref) before evaluation.
857853
858- # without VarInfo
859- function AbstractPPL. evaluate!! (
860- model:: Model ,
861- rng:: Random.AbstractRNG ,
862- sampler:: AbstractSampler ,
863- args:: AbstractContext... ,
864- )
865- return evaluate!! (model, rng, VarInfo (), sampler, args... )
854+ Returns a tuple of the model's return value, plus the updated `varinfo`
855+ (unwrapped if necessary).
856+ """
857+ function AbstractPPL. evaluate!! (model:: Model , varinfo:: AbstractVarInfo )
858+ return if use_threadsafe_eval (model. context, varinfo)
859+ evaluate_threadsafe!! (model, varinfo)
860+ else
861+ evaluate_threadunsafe!! (model, varinfo)
862+ end
866863end
867-
868- # without VarInfo and without AbstractSampler
869864function AbstractPPL. evaluate!! (
870- model:: Model , rng :: Random.AbstractRNG , context:: AbstractContext
865+ model:: Model , varinfo :: AbstractVarInfo , context:: AbstractContext
871866)
872- return evaluate!! (model, rng, VarInfo (), SampleFromPrior (), context)
867+ new_ctx = combine_model_and_external_contexts (model. context, context)
868+ model = contextualize (model, new_ctx)
869+ return evaluate!! (model, varinfo)
873870end
874871
875872"""
876- evaluate_threadunsafe!!(model, varinfo, context )
873+ evaluate_threadunsafe!!(model, varinfo)
877874
878875Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`.
879876
@@ -882,8 +879,8 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
882879
883880See also: [`evaluate_threadsafe!!`](@ref)
884881"""
885- function evaluate_threadunsafe!! (model, varinfo, context )
886- return _evaluate!! (model, resetlogp!! (varinfo), context )
882+ function evaluate_threadunsafe!! (model, varinfo)
883+ return _evaluate!! (model, resetlogp!! (varinfo))
887884end
888885
889886"""
@@ -897,31 +894,74 @@ This method is not exposed and supposed to be used only internally in DynamicPPL
897894
898895See also: [`evaluate_threadunsafe!!`](@ref)
899896"""
900- function evaluate_threadsafe!! (model, varinfo, context )
897+ function evaluate_threadsafe!! (model, varinfo)
901898 wrapper = ThreadSafeVarInfo (resetlogp!! (varinfo))
902- result, wrapper_new = _evaluate!! (model, wrapper, context)
899+ result, wrapper_new = _evaluate!! (model, wrapper)
900+ # TODO (penelopeysm): If seems that if you pass a TSVI to this method, it
901+ # will return the underlying VI, which is a bit counterintuitive (because
902+ # calling TSVI(::TSVI) returns the original TSVI, instead of wrapping it
903+ # again).
903904 return result, setaccs!! (wrapper_new. varinfo, getaccs (wrapper_new))
904905end
905906
906907"""
908+ _evaluate!!(model::Model, varinfo)
907909 _evaluate!!(model::Model, varinfo, context)
908910
909- Evaluate the `model` with the arguments matching the given `context` and `varinfo` object.
911+ Evaluate the `model` with the given `varinfo`. If an additional `context` is provided,
912+ the model's context is combined with that context.
913+
914+ This function does not wrap the varinfo in a `ThreadSafeVarInfo`.
910915"""
911- function _evaluate!! (model:: Model , varinfo:: AbstractVarInfo , context :: AbstractContext )
912- args, kwargs = make_evaluate_args_and_kwargs (model, varinfo, context )
916+ function _evaluate!! (model:: Model , varinfo:: AbstractVarInfo )
917+ args, kwargs = make_evaluate_args_and_kwargs (model, varinfo)
913918 return model. f (args... ; kwargs... )
914919end
920+ function _evaluate!! (model:: Model , varinfo:: AbstractVarInfo , context:: AbstractContext )
921+ # TODO (penelopeysm): We don't really need this, but it's a useful
922+ # convenience method. We could remove it after we get rid of the
923+ # evaluate_threadsafe!! stuff (in favour of making users call evaluate!!
924+ # with a TSVI themselves).
925+ new_ctx = combine_model_and_external_contexts (model. context, context)
926+ model = contextualize (model, new_ctx)
927+ return _evaluate!! (model, varinfo)
928+ end
915929
916930is_splat_symbol (s:: Symbol ) = startswith (string (s), " #splat#" )
917931
932+ """
933+ combine_model_and_external_contexts(model_context, external_context)
934+
935+ Combine a context from a model and an external context into a single context.
936+
937+ The resulting context stack has the following structure:
938+
939+ `external_context` -> `childcontext(external_context)` -> ... ->
940+ `model_context` -> `childcontext(model_context)` -> ... ->
941+ `leafcontext(external_context)`
942+
943+ The reason for this is that we want to give `external_context` precedence over
944+ `model_context`, while also preserving the leaf context of `external_context`.
945+ We can do this by
946+
947+ 1. Set the leaf context of `model_context` to `leafcontext(external_context)`.
948+ 2. Set leaf context of `external_context` to the context resulting from (1).
949+ """
950+ function combine_model_and_external_contexts (
951+ model_context:: AbstractContext , external_context:: AbstractContext
952+ )
953+ return setleafcontext (
954+ external_context, setleafcontext (model_context, leafcontext (external_context))
955+ )
956+ end
957+
918958"""
919959 make_evaluate_args_and_kwargs(model, varinfo, context)
920960
921961Return the arguments and keyword arguments to be passed to the evaluator of the model, i.e. `model.f`e.
922962"""
923963@generated function make_evaluate_args_and_kwargs (
924- model:: Model{_F,argnames} , varinfo:: AbstractVarInfo , context :: AbstractContext
964+ model:: Model{_F,argnames} , varinfo:: AbstractVarInfo
925965) where {_F,argnames}
926966 unwrap_args = [
927967 if is_splat_symbol (var)
@@ -930,18 +970,7 @@ Return the arguments and keyword arguments to be passed to the evaluator of the
930970 :($ matchingvalue (varinfo, model. args.$ var))
931971 end for var in argnames
932972 ]
933-
934- # We want to give `context` precedence over `model.context` while also
935- # preserving the leaf context of `context`. We can do this by
936- # 1. Set the leaf context of `model.context` to `leafcontext(context)`.
937- # 2. Set leaf context of `context` to the context resulting from (1).
938- # The result is:
939- # `context` -> `childcontext(context)` -> ... -> `model.context`
940- # -> `childcontext(model.context)` -> ... -> `leafcontext(context)`
941973 return quote
942- context_new = setleafcontext (
943- context, setleafcontext (model. context, leafcontext (context))
944- )
945974 args = (
946975 model,
947976 # Maybe perform `invlink!!` once prior to evaluation to avoid
0 commit comments