diff --git a/HISTORY.md b/HISTORY.md
index f181897f7..0f0102ce4 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -21,6 +21,21 @@ You should not need to use these directly, please use `AbstractPPL.condition` an
 
 Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
 
+The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
+This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
+
+### Other changes
+
+#### FastLDF
+
+Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation and automatic differentiation.
+Exact speedups depend on the model size: larger models see smaller speedups, because the bulk of the work is done in calls to `logpdf`.
+
+Please note that `FastLDF` is currently considered internal and its API may change without warning.
+We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.
+
+For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
+
 ## 0.38.9
 
 Remove warning when using Enzyme as the AD backend.
diff --git a/docs/src/api.md b/docs/src/api.md
index 63dafdfca..e81f18dc7 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -170,6 +170,12 @@ DynamicPPL.prefix
 
 ## Utilities
 
+`typed_identity` is the same as `identity`, but with an overload for `with_logabsdet_jacobian` that ensures that it never errors.
+
+```@docs
+typed_identity
+```
+
 It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
 
 ```@docs
@@ -517,10 +523,12 @@ InitFromParams
 ```
 
 If you wish to write your own, you have to subtype [`DynamicPPL.AbstractInitStrategy`](@ref) and implement the `init` method.
+In very rare situations, you may also need to implement `get_param_eltype`, which defines the element type of the parameters generated by the strategy.
 
 ```@docs
-DynamicPPL.AbstractInitStrategy
-DynamicPPL.init
+AbstractInitStrategy
+init
+get_param_eltype
 ```
 
 ### Choosing a suitable VarInfo
diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl
index 35159636f..ef21c255b 100644
--- a/ext/DynamicPPLEnzymeCoreExt.jl
+++ b/ext/DynamicPPLEnzymeCoreExt.jl
@@ -1,16 +1,15 @@
 module DynamicPPLEnzymeCoreExt
 
-if isdefined(Base, :get_extension)
-    using DynamicPPL: DynamicPPL
-    using EnzymeCore
-else
-    using ..DynamicPPL: DynamicPPL
-    using ..EnzymeCore
-end
+using DynamicPPL: DynamicPPL
+using EnzymeCore
 
 # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
 # only checks whether such a method exists, and never runs it.
 @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = nothing
 
+# Likewise for `_get_range_and_linked`.
+@inline EnzymeCore.EnzymeRules.inactive(
+    ::typeof(DynamicPPL._get_range_and_linked), args...
+) = nothing
 
 end
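To make the new `init` contract from the changelog entry above concrete, here is a minimal sketch (not part of the diff; `InitFromMean` is a hypothetical strategy). Since the generated value is already in unlinked space, it is paired with `typed_identity`:

```julia
using DynamicPPL, Distributions, Random

# Hypothetical strategy: initialise every variable at its prior mean.
struct InitFromMean <: DynamicPPL.AbstractInitStrategy end

function DynamicPPL.init(
    ::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromMean
)
    # The value is already unlinked, so pair it with `typed_identity`,
    # which contributes a zero log-Jacobian term.
    return mean(dist), DynamicPPL.typed_identity
end
```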
diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl
index 23a3430eb..8adf66030 100644
--- a/ext/DynamicPPLMooncakeExt.jl
+++ b/ext/DynamicPPLMooncakeExt.jl
@@ -5,5 +5,8 @@ using Mooncake: Mooncake
 
 # This is purely an optimisation.
 Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}
+Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{
+    typeof(DynamicPPL._get_range_and_linked),Vararg
+}
 
 end # module
diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index c43bd89d5..e9b902363 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -84,8 +84,8 @@ export AbstractVarInfo,
     # Compiler
     @model,
     # Utilities
-    init,
     OrderedDict,
+    typed_identity,
     # Model
     Model,
    getmissings,
@@ -113,6 +113,8 @@ export AbstractVarInfo,
     InitFromPrior,
     InitFromUniform,
     InitFromParams,
+    init,
+    get_param_eltype,
     # Pseudo distributions
     NamedDist,
     NoDist,
@@ -193,6 +195,7 @@ include("abstract_varinfo.jl")
 include("threadsafe.jl")
 include("varinfo.jl")
 include("simple_varinfo.jl")
+include("onlyaccs.jl")
 include("compiler.jl")
 include("pointwise_logdensities.jl")
 include("logdensityfunction.jl")
diff --git a/src/compiler.jl b/src/compiler.jl
index badba9f9d..3324780ca 100644
--- a/src/compiler.jl
+++ b/src/compiler.jl
@@ -718,14 +718,15 @@ end
 # TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
 # TODO(mhauru) This function needs a more comprehensive docstring.
 """
-    matchingvalue(vi, value)
+    matchingvalue(param_eltype, value)
 
-Convert the `value` to the correct type for the `vi` object.
+Convert the `value` to the correct type, given the element type of the parameters
+being used to evaluate the model.
 """
-function matchingvalue(vi, value)
+function matchingvalue(param_eltype, value)
     T = typeof(value)
     if hasmissing(T)
-        _value = convert(get_matching_type(vi, T), value)
+        _value = convert(get_matching_type(param_eltype, T), value)
         # TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
         # are happy to return `value` as-is?
         if _value === value
@@ -738,29 +739,30 @@ end
     end
 end
 
-function matchingvalue(vi, value::FloatOrArrayType)
-    return get_matching_type(vi, value)
+function matchingvalue(param_eltype, value::FloatOrArrayType)
+    return get_matching_type(param_eltype, value)
 end
-function matchingvalue(vi, ::TypeWrap{T}) where {T}
-    return TypeWrap{get_matching_type(vi, T)}()
+function matchingvalue(param_eltype, ::TypeWrap{T}) where {T}
+    return TypeWrap{get_matching_type(param_eltype, T)}()
 end
 
 # TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
 """
-    get_matching_type(vi, ::TypeWrap{T}) where {T}
+    get_matching_type(param_eltype, ::TypeWrap{T}) where {T}
 
-Get the specialized version of type `T` for `vi`.
+Get the specialized version of type `T`, given an element type of the parameters
+being used to evaluate the model.
 """
 get_matching_type(_, ::Type{T}) where {T} = T
-function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
-    return Union{Missing,float_type_with_fallback(eltype(vi))}
+function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}})
+    return Union{Missing,float_type_with_fallback(param_eltype)}
 end
-function get_matching_type(vi, ::Type{<:AbstractFloat})
-    return float_type_with_fallback(eltype(vi))
+function get_matching_type(param_eltype, ::Type{<:AbstractFloat})
+    return float_type_with_fallback(param_eltype)
 end
-function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N}
-    return Array{get_matching_type(vi, T),N}
+function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N}
    return Array{get_matching_type(param_eltype, T),N}
 end
-function get_matching_type(vi, ::Type{<:Array{T}}) where {T}
-    return Array{get_matching_type(vi, T)}
+function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T}
+    return Array{get_matching_type(param_eltype, T)}
 end
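A rough illustration of the renamed `get_matching_type` contract (a sketch; the exact promoted types depend on how `float_type_with_fallback` resolves the given eltype). The first argument is now the parameter element type rather than a VarInfo, and array types are promoted elementwise:

```julia
using DynamicPPL
import ForwardDiff

# Plain Float64 parameters leave argument types unchanged:
DynamicPPL.get_matching_type(Float64, Vector{Float64})  # Vector{Float64}

# Tracer parameter eltypes (e.g. ForwardDiff duals) should widen Float64
# storage so that duals can be written into model arguments:
D = ForwardDiff.Dual{Nothing,Float64,1}
DynamicPPL.get_matching_type(D, Vector{Float64})        # Vector{D} (expected)
```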
""" get_matching_type(_, ::Type{T}) where {T} = T -function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}}) - return Union{Missing,float_type_with_fallback(eltype(vi))} +function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}}) + return Union{Missing,float_type_with_fallback(param_eltype)} end -function get_matching_type(vi, ::Type{<:AbstractFloat}) - return float_type_with_fallback(eltype(vi)) +function get_matching_type(param_eltype, ::Type{<:AbstractFloat}) + return float_type_with_fallback(param_eltype) end -function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N} - return Array{get_matching_type(vi, T),N} +function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N} + return Array{get_matching_type(param_eltype, T),N} end -function get_matching_type(vi, ::Type{<:Array{T}}) where {T} - return Array{get_matching_type(vi, T)} +function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T} + return Array{get_matching_type(param_eltype, T)} end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 83507353f..065a11f29 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -1,11 +1,11 @@ """ AbstractInitStrategy -Abstract type representing the possible ways of initialising new values for -the random variables in a model (e.g., when creating a new VarInfo). +Abstract type representing the possible ways of initialising new values for the random +variables in a model (e.g., when creating a new VarInfo). -Any subtype of `AbstractInitStrategy` must implement the -[`DynamicPPL.init`](@ref) method. +Any subtype of `AbstractInitStrategy` must implement the [`DynamicPPL.init`](@ref) method, +and very rarely, [`DynamicPPL.get_param_eltype`](@ref). """ abstract type AbstractInitStrategy end @@ -14,14 +14,50 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -!!! warning "Return values must be unlinked" - The values returned by `init` must always be in the untransformed space, i.e., - they must be within the support of the original distribution. That means that, - for example, `init(rng, dist, u::InitFromUniform)` will in general return values that - are outside the range [u.lower, u.upper]. +This function must return a tuple `(x, trf)`, where + +- `x` is the generated value + +- `trf` is a function that transforms the generated value back to the unlinked space. If the + value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You + can also use `Base.identity`, but if you use this, you **must** be confident that + `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more + information. """ function init end +""" + DynamicPPL.get_param_eltype(strategy::AbstractInitStrategy) + +Return the element type of the parameters generated from the given initialisation strategy. + +The default implementation returns `Any`. However, for `InitFromParams` which provides known +parameters for evaluating the model, methods are implemented in order to return more specific +types. + +For the most part, a return value of `Any` will actually suffice. However, there are a few +edge cases in DynamicPPL where the element type is needed. These largely relate to +determining the element type of accumulators ahead of time (_before_ evaluation), as well as +promoting type parameters in model arguments. 
 """
     InitFromPrior()
 
 Obtain new values by sampling from the prior distribution.
 """
 struct InitFromPrior <: AbstractInitStrategy end
 function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior)
-    return rand(rng, dist)
+    return rand(rng, dist), typed_identity
 end
 
 """
@@ -69,43 +105,66 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro
     if x isa Array{<:Any,0}
         x = x[]
     end
-    return x
+    return x, typed_identity
 end
 
 """
     InitFromParams(
-        params::Union{AbstractDict{<:VarName},NamedTuple},
+        params::Any,
         fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
     )
 
-Obtain new values by extracting them from the given dictionary or NamedTuple.
+Obtain new values by extracting them from the given set of `params`.
+
+The most common use case is to provide a `NamedTuple` or `AbstractDict{<:VarName}`, which
+provides a mapping from variable names to values. However, we leave the type of `params`
+open in order to allow for custom parameter storage types.
+
+## Custom parameter storage types
 
-The parameter `fallback` specifies how new values are to be obtained if they
-cannot be found in `params`, or they are specified as `missing`. `fallback`
-can either be an initialisation strategy itself, in which case it will be
-used to obtain new values, or it can be `nothing`, in which case an error
-will be thrown. The default for `fallback` is `InitFromPrior()`.
+For `InitFromParams` to work correctly with a custom `params::P`, you need to implement
 
-!!! note
-    The values in `params` must be provided in the space of the untransformed
-    distribution.
+```julia
+DynamicPPL.init(rng, vn::VarName, dist::Distribution, p::InitFromParams{P}) where {P}
+```
+
+This tells you how to obtain values for the random variable `vn` from `p.params`. Note that
+the last argument is `InitFromParams(params)`, not just `params` itself. Please see the
+docstring of [`DynamicPPL.init`](@ref) for more information on the expected behaviour.
+
+If you only use `InitFromParams` with `DynamicPPL.OnlyAccsVarInfo`, as is usually the case,
+then you will not need to implement anything else. So far, this is the same as you would do
+when creating any new `AbstractInitStrategy` subtype.
+
+However, to use `InitFromParams` with a full `DynamicPPL.VarInfo`, you *may* also need to
+implement
+
+```julia
+DynamicPPL.get_param_eltype(p::InitFromParams{P}) where {P}
+```
+
+See the docstring of [`DynamicPPL.get_param_eltype`](@ref) for more information on when this
+is needed.
+
+The argument `fallback` specifies how new values are to be obtained if they cannot be found
+in `params`, or they are specified as `missing`. `fallback` can either be an initialisation
+strategy itself, in which case it will be used to obtain new values, or it can be `nothing`,
+in which case an error will be thrown. The default for `fallback` is `InitFromPrior()`.
 """
 struct InitFromParams{P,S<:Union{AbstractInitStrategy,Nothing}} <: AbstractInitStrategy
     params::P
     fallback::S
+
     function InitFromParams(
-        params::AbstractDict{<:VarName},
-        fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior(),
-    )
-        return new{typeof(params),typeof(fallback)}(params, fallback)
-    end
-    function InitFromParams(
-        params::NamedTuple, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
-    )
-        return InitFromParams(to_varname_dict(params), fallback)
+        params::P, fallback::Union{AbstractInitStrategy,Nothing}=InitFromPrior()
+    ) where {P}
+        return new{P,typeof(fallback)}(params, fallback)
     end
 end
-function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams)
+
+function init(
+    rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P}
+) where {P<:Union{AbstractDict{<:VarName},NamedTuple}}
     # TODO(penelopeysm): It would be nice to do a check to make sure that all
     # of the parameters in `p.params` were actually used, and either warn or
     # error if they aren't. This is actually quite non-trivial though because
@@ -119,13 +178,89 @@ function init(rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitF
     else
         # TODO(penelopeysm): Since x is user-supplied, maybe we could also
         # check here that the type / size of x matches the dist?
-        x
+        x, typed_identity
         end
     else
         p.fallback === nothing && error("No value was provided for the variable `$(vn)`.")
         init(rng, vn, dist, p.fallback)
     end
 end
+function get_param_eltype(
+    strategy::InitFromParams{<:Union{AbstractDict{<:VarName},NamedTuple}}
+)
+    return infer_nested_eltype(typeof(strategy.params))
+end
+
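As a sketch of the custom-storage recipe described in the `InitFromParams` docstring above (all names here are hypothetical): a storage type that keys unlinked values by the variable's symbol only needs an `init` method.

```julia
using DynamicPPL, Distributions, Random
using AbstractPPL: getsym

# Hypothetical storage: unlinked values keyed by symbol.
struct SymbolParams
    values::Dict{Symbol,Float64}
end

function DynamicPPL.init(
    ::Random.AbstractRNG, vn::VarName, ::Distribution, p::InitFromParams{SymbolParams}
)
    # Values are already unlinked, so pair them with `typed_identity`.
    return p.params.values[getsym(vn)], DynamicPPL.typed_identity
end

# Usage: InitFromParams(SymbolParams(Dict(:x => 0.1)), nothing)
```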
+"""
+    RangeAndLinked
+
+Suppose we have vectorised parameters `params::AbstractVector{<:Real}`. Each random variable
+in the model will in general correspond to a sub-vector of `params`. This struct stores
+information about that range, as well as whether the sub-vector represents a linked value or
+an unlinked value.
+
+$(TYPEDFIELDS)
+"""
+struct RangeAndLinked
+    # indices that the variable corresponds to in the vectorised parameter
+    range::UnitRange{Int}
+    # whether it's linked
+    is_linked::Bool
+end
+
+"""
+    VectorWithRanges(
+        iden_varname_ranges::NamedTuple,
+        varname_ranges::Dict{VarName,RangeAndLinked},
+        vect::AbstractVector{<:Real},
+    )
+
+A struct that wraps a vector of parameter values, plus information about how random
+variables map to ranges in that vector.
+
+In the simplest case, this could be accomplished only with a single dictionary mapping
+VarNames to ranges and link status. However, for performance reasons, we separate out
+VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All
+non-identity-optic VarNames are stored in the `varname_ranges` Dict.
+
+It would be nice to improve the NamedTuple and Dict approach. See, e.g.,
+https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
+"""
+struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
+    # This NamedTuple stores the ranges for identity VarNames
+    iden_varname_ranges::N
+    # This Dict stores the ranges for all other VarNames
+    varname_ranges::Dict{VarName,RangeAndLinked}
+    # The full parameter vector which we index into to get variable values
+    vect::T
+end
+
+function _get_range_and_linked(
+    vr::VectorWithRanges, ::VarName{sym,typeof(identity)}
+) where {sym}
+    return vr.iden_varname_ranges[sym]
+end
+function _get_range_and_linked(vr::VectorWithRanges, vn::VarName)
+    return vr.varname_ranges[vn]
+end
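For concreteness, a hand-built sketch of these two structs for a model with variables `x` (identity optic) and `y[1]` (in practice they are derived from a VarInfo, as `FastLDF` does below):

```julia
using DynamicPPL
using DynamicPPL: RangeAndLinked, VectorWithRanges

vr = VectorWithRanges(
    # identity-optic VarNames live in the NamedTuple, keyed by symbol
    (x = RangeAndLinked(1:1, false),),
    # all other VarNames go in the Dict
    Dict{VarName,RangeAndLinked}(@varname(y[1]) => RangeAndLinked(2:2, false)),
    # the full parameter vector
    [0.5, 1.5],
)
```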
+function init(
+    ::Random.AbstractRNG,
+    vn::VarName,
+    dist::Distribution,
+    p::InitFromParams{<:VectorWithRanges},
+)
+    vr = p.params
+    range_and_linked = _get_range_and_linked(vr, vn)
+    transform = if range_and_linked.is_linked
+        from_linked_vec_transform(dist)
+    else
+        from_vec_transform(dist)
+    end
+    return (@view vr.vect[range_and_linked.range]), transform
+end
+function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges})
+    return eltype(strategy.params.vect)
+end
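Continuing the `vr` sketch above: `init` reads the sub-vector for a variable and pairs it with the matching vector-to-value transform, so a caller recovers the value via `with_logabsdet_jacobian`:

```julia
using Distributions, Random
using Bijectors: with_logabsdet_jacobian

val, trf = DynamicPPL.init(
    Random.default_rng(), @varname(x), Normal(), InitFromParams(vr, nothing)
)
x, logjac = with_logabsdet_jacobian(trf, val)
# x ≈ 0.5 and logjac == 0.0, since `x` is stored unlinked in `vr`
```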
 
 """
     InitContext(
@@ -155,9 +290,8 @@ function tilde_assume!!(
     ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
 )
     in_varinfo = haskey(vi, vn)
-    # `init()` always returns values in original space, i.e. possibly
-    # constrained
-    x = init(ctx.rng, vn, dist, ctx.strategy)
+    val, transform = init(ctx.rng, vn, dist, ctx.strategy)
+    x, inv_logjac = with_logabsdet_jacobian(transform, val)
     # Determine whether to insert a transformed value into the VarInfo.
     # If the VarInfo alrady had a value for this variable, we will
     # keep the same linked status as in the original VarInfo. If not, we
@@ -165,17 +299,49 @@
     # is_transformed(vi) returns true if vi is nonempty and all variables in vi
     # are linked.
     insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi)
-    y, logjac = if insert_transformed_value
-        with_logabsdet_jacobian(link_transform(dist), x)
+    val_to_insert, logjac = if insert_transformed_value
+        # Calculate the forward logjac and sum them up.
+        y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x)
+        # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian
+        # calculation wastes a lot of time going from linked vectorised -> unlinked ->
+        # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`.
+        #
+        # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which
+        # case this branch is never hit (since `in_varinfo` will always be false). It does
+        # mean that the combination of InitFromParams{<:VectorWithRanges} with a full,
+        # linked, VarInfo will be very slow. That should never really be used, though. So
+        # (at least for now) we can leave this branch in for full generality with other
+        # combinations of init strategies / VarInfo.
+        #
+        # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue
+        # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`,
+        # which is NOT the same as `inverse(link_transform)` (because there is an additional
+        # vectorisation step). We need `init` and `tilde_assume!!` to share this information
+        # but it's not clear right now how to do this. In my opinion, there are a couple of
+        # potential ways forward:
+        #
+        # 1. Just remove metadata entirely so that there is never any need to construct
+        # a linked vectorised value again. This would require us to use VAIMAcc as the only
+        # way of getting values. I consider this the best option, but it might take a long
+        # time.
+        #
+        # 2. Clean up the behaviour of bijectors so that we can have a complete separation
+        # between the linking and vectorisation parts of it. That way, `x` can either be
+        # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of
+        # which it is, we should only need to apply at most one linking and one
+        # vectorisation transform. Doing so would allow us to remove the first call to
+        # `with_logabsdet_jacobian`, and instead compose and/or uncompose the
+        # transformations before calling `with_logabsdet_jacobian` once.
+        y, -inv_logjac + fwd_logjac
     else
-        x, zero(LogProbType)
+        x, -inv_logjac
     end
     # Add the new value to the VarInfo. `push!!` errors if the value already
     # exists, hence the need for setindex!!.
     if in_varinfo
-        vi = setindex!!(vi, y, vn)
+        vi = setindex!!(vi, val_to_insert, vn)
     else
-        vi = push!!(vi, vn, y, dist)
+        vi = push!!(vi, vn, val_to_insert, dist)
     end
     # Neither of these set the `trans` flag so we have to do it manually if
     # necessary.
diff --git a/src/experimental.jl b/src/experimental.jl
index 8c82dca68..c644c09b2 100644
--- a/src/experimental.jl
+++ b/src/experimental.jl
@@ -2,6 +2,8 @@ module Experimental
 
 using DynamicPPL: DynamicPPL
 
+include("fasteval.jl")
+
 # This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
 """
     is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)
diff --git a/src/fasteval.jl b/src/fasteval.jl
new file mode 100644
index 000000000..5b9b767df
--- /dev/null
+++ b/src/fasteval.jl
@@ -0,0 +1,332 @@
+using DynamicPPL:
+    AbstractVarInfo,
+    AccumulatorTuple,
+    InitContext,
+    InitFromParams,
+    LogJacobianAccumulator,
+    LogLikelihoodAccumulator,
+    LogPriorAccumulator,
+    Model,
+    ThreadSafeVarInfo,
+    VarInfo,
+    OnlyAccsVarInfo,
+    RangeAndLinked,
+    VectorWithRanges,
+    Metadata,
+    VarNamedVector,
+    default_accumulators,
+    float_type_with_fallback,
+    getlogjoint,
+    getlogjoint_internal,
+    getloglikelihood,
+    getlogprior,
+    getlogprior_internal
+using ADTypes: ADTypes
+using BangBang: BangBang
+using AbstractPPL: AbstractPPL, VarName
+using LogDensityProblems: LogDensityProblems
+import DifferentiationInterface as DI
+using Random: Random
+
+"""
+    FastLDF(
+        model::Model,
+        getlogdensity::Function=getlogjoint_internal,
+        varinfo::AbstractVarInfo=VarInfo(model);
+        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+    )
+
+A struct which contains a model, along with all the information necessary to:
+
+  - calculate its log density at a given point;
+  - and if `adtype` is provided, calculate the gradient of the log density at that point.
+
+This information can be extracted using the LogDensityProblems.jl interface, specifically,
+using `LogDensityProblems.logdensity` and `LogDensityProblems.logdensity_and_gradient`. If
+`adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a concrete AD
+backend type, then `logdensity_and_gradient` is also implemented.
+
+There are several options for `getlogdensity` that are 'supported' out of the box:
+
+- [`getlogjoint_internal`](@ref): calculate the log joint, including the log-Jacobian term
+  for any variables that have been linked in the provided VarInfo.
+- [`getlogprior_internal`](@ref): calculate the log prior, including the log-Jacobian term
+  for any variables that have been linked in the provided VarInfo.
+- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring any effects of
+  linking
+- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring any effects of
+  linking
+- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected by linking,
+  since transforms are only applied to random variables)
+
+!!! note
+    By default, `FastLDF` uses `getlogjoint_internal`, i.e., the result of
+    `LogDensityProblems.logdensity(f, x)` will depend on whether the `FastLDF` was created
+    with a linked or unlinked VarInfo. This is done primarily to ease interoperability with
+    MCMC samplers.
+
+If you provide one of these functions, a `VarInfo` will be automatically created for you. If
+you provide a different function, you have to manually create a VarInfo and pass it as the
+third argument.
+
+If the `adtype` keyword argument is provided, then this struct will also store the adtype
+along with other information for efficient calculation of the gradient of the log density.
+Note that preparing a `FastLDF` with an AD type `AutoBackend()` requires the AD backend
+itself to have been loaded (e.g. with `import Backend`).
+
+## Fields
+
+Note that it is undefined behaviour to access any of a `FastLDF`'s fields, apart from:
+
+- `fastldf.model`: The original model from which this `FastLDF` was constructed.
+- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
+  type was provided.
+
+## Extended help
+
+Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a
+given set of parameters:
+
+1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters
+   inside a VarInfo's metadata, then reads parameter values from the VarInfo during evaluation.
+
+2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores
+   them inside a VarInfo's metadata.
+
+In general, both of these approaches work fine, but the fact that they modify the VarInfo's
+metadata can often be quite wasteful. In particular, it is very common that the only outputs
+we care about from model evaluation are those which are stored in accumulators, such as log
+probability densities, or `ValuesAsInModel`.
+
+To avoid this issue, we use `OnlyAccsVarInfo`, which is a VarInfo that only contains
+accumulators. It implements enough of the `AbstractVarInfo` interface to not error during
+model evaluation.
+
+Because `OnlyAccsVarInfo` does not store any parameter values, when evaluating a model with
+it, it is mandatory that parameters are provided from outside the VarInfo, namely via
+`InitContext{<:InitFromParams}`.
+
+The main problem that we face is that it is not possible to directly implement
+`DynamicPPL.init(rng, vn, dist, strategy)` for `strategy::InitFromParams{<:AbstractVector}`.
+In particular, it is not clear:
+
+  - which parts of the vector correspond to which random variables, and
+  - whether the variables are linked or unlinked.
+
+Traditionally, this problem has been solved by `unflatten`, because that function would
+place values into the VarInfo's metadata alongside the information about ranges and linking.
+That way, when we evaluate with `DefaultContext`, we can read this information out again.
+However, we want to avoid doing this. Thus, here, we _extract this information from the
+VarInfo_ a single time when constructing a `FastLDF` object. Inside the `FastLDF`, we store:
+
+  - the vector of parameters
+  - a mapping from VarNames to ranges in that vector, along with link status
+
+When evaluating the model, this allows us to create an `InitFromParams{VectorWithRanges}`,
+which lets us very quickly read parameter values from the vector.
+
+Note that this assumes that the ranges and link status are static throughout the lifetime of
+the `FastLDF` object. Therefore, a `FastLDF` object cannot handle models which have variable
+numbers of parameters, or models which may visit random variables in different orders
+depending on stochastic control flow. **Indeed, silent errors may occur with such models.**
+This is a general limitation of vectorised parameters: the original `unflatten` +
+`evaluate!!` approach also fails with such models.
+"""
+struct FastLDF{
+    M<:Model,
+    AD<:Union{ADTypes.AbstractADType,Nothing},
+    F<:Function,
+    N<:NamedTuple,
+    ADP<:Union{Nothing,DI.GradientPrep},
+}
+    model::M
+    adtype::AD
+    _getlogdensity::F
+    _iden_varname_ranges::N
+    _varname_ranges::Dict{VarName,RangeAndLinked}
+    _adprep::ADP
+
+    function FastLDF(
+        model::Model,
+        getlogdensity::Function=getlogjoint_internal,
+        varinfo::AbstractVarInfo=VarInfo(model);
+        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+    )
+        # Figure out which variable corresponds to which index, and
+        # which variables are linked.
+        all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
+        # Do AD prep if needed
+        prep = if adtype === nothing
+            nothing
+        else
+            # Make backend-specific tweaks to the adtype
+            adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
+            x = [val for val in varinfo[:]]
+            DI.prepare_gradient(
+                FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
+                adtype,
+                x,
+            )
+        end
+        return new{
+            typeof(model),
+            typeof(adtype),
+            typeof(getlogdensity),
+            typeof(all_iden_ranges),
+            typeof(prep),
+        }(
+            model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep
+        )
+    end
+end
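A minimal usage sketch of the interface this docstring describes (the model here is hypothetical):

```julia
using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoForwardDiff
import ForwardDiff

@model function demo()
    x ~ Normal()
    return 1.0 ~ Normal(x)
end

fldf = DynamicPPL.Experimental.FastLDF(demo(); adtype=AutoForwardDiff())
LogDensityProblems.logdensity(fldf, [0.3])
LogDensityProblems.logdensity_and_gradient(fldf, [0.3])
```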
+
+###################################
+# LogDensityProblems.jl interface #
+###################################
+"""
+    fast_ldf_accs(getlogdensity::Function)
+
+Determine which accumulators are needed for fast evaluation with the given
+`getlogdensity` function.
+"""
+fast_ldf_accs(::Function) = default_accumulators()
+fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
+function fast_ldf_accs(::typeof(getlogjoint))
+    return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
+end
+function fast_ldf_accs(::typeof(getlogprior_internal))
+    return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator()))
+end
+fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
+fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
+
+struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
+    _model::M
+    _getlogdensity::F
+    _iden_varname_ranges::N
+    _varname_ranges::Dict{VarName,RangeAndLinked}
+end
+function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
+    ctx = InitContext(
+        Random.default_rng(),
+        InitFromParams(
+            VectorWithRanges(f._iden_varname_ranges, f._varname_ranges, params), nothing
+        ),
+    )
+    model = DynamicPPL.setleafcontext(f._model, ctx)
+    accs = fast_ldf_accs(f._getlogdensity)
+    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
+    # which is unnecessary. So we short-circuit this by simply calling `_evaluate!!`
+    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
+    # here.
+    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
+    # it _should_ do, but this is wrong regardless.
+    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
+    vi = if Threads.nthreads() > 1
+        accs = map(
+            acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
+            accs,
+        )
+        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
+    else
+        OnlyAccsVarInfo(accs)
+    end
+    _, vi = DynamicPPL._evaluate!!(model, vi)
+    return f._getlogdensity(vi)
+end
+
+function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
+    return FastLogDensityAt(
+        fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+    )(
+        params
+    )
+end
+
+function LogDensityProblems.logdensity_and_gradient(
+    fldf::FastLDF, params::AbstractVector{<:Real}
+)
+    return DI.value_and_gradient(
+        FastLogDensityAt(
+            fldf.model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
+        ),
+        fldf._adprep,
+        fldf.adtype,
+        params,
+    )
+end
+
+######################################################
+# Helper functions to extract ranges and link status #
+######################################################
+
+# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The
+# fact is that evaluation doesn't use a VarInfo; it is only used once, to generate the
+# ranges and link status. So there is no motivation to use SimpleVarInfo inside a
+# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue
+# that there is no purpose in supporting untyped VarInfo either.
+"""
+    get_ranges_and_linked(varinfo::VarInfo)
+
+Given a `VarInfo`, extract the ranges of each variable in the vectorised parameter
+representation, along with whether each variable is linked or unlinked.
+
+This function returns a tuple containing:
+
+- A NamedTuple mapping VarNames with identity optics to their corresponding `RangeAndLinked`
+- A Dict mapping all other VarNames to their corresponding `RangeAndLinked`.
+"""
+function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms}
+    all_iden_ranges = NamedTuple()
+    all_ranges = Dict{VarName,RangeAndLinked}()
+    offset = 1
+    for sym in syms
+        md = varinfo.metadata[sym]
+        this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset)
+        all_iden_ranges = merge(all_iden_ranges, this_md_iden)
+        all_ranges = merge(all_ranges, this_md_others)
+    end
+    return all_iden_ranges, all_ranges
+end
+function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}})
+    all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1)
+    return all_iden, all_others
+end
+function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int)
+    all_iden_ranges = NamedTuple()
+    all_ranges = Dict{VarName,RangeAndLinked}()
+    offset = start_offset
+    for (vn, idx) in md.idcs
+        is_linked = md.is_transformed[idx]
+        range = md.ranges[idx] .+ (start_offset - 1)
+        if AbstractPPL.getoptic(vn) === identity
+            all_iden_ranges = merge(
+                all_iden_ranges,
+                NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
+            )
+        else
+            all_ranges[vn] = RangeAndLinked(range, is_linked)
+        end
+        offset += length(range)
+    end
+    return all_iden_ranges, all_ranges, offset
+end
+function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int)
+    all_iden_ranges = NamedTuple()
+    all_ranges = Dict{VarName,RangeAndLinked}()
+    offset = start_offset
+    for (vn, idx) in vnv.varname_to_index
+        is_linked = vnv.is_unconstrained[idx]
+        range = vnv.ranges[idx] .+ (start_offset - 1)
+        if AbstractPPL.getoptic(vn) === identity
+            all_iden_ranges = merge(
+                all_iden_ranges,
+                NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)),
+            )
+        else
+            all_ranges[vn] = RangeAndLinked(range, is_linked)
+        end
+        offset += length(range)
+    end
+    return all_iden_ranges, all_ranges, offset
+end
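For example, mirroring how the test suite below exercises this helper (`demo` as in the usage sketch earlier; the exact range assumes `x` is the only variable):

```julia
vi = VarInfo(demo())
nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi)
nt_ranges.x  # RangeAndLinked(1:1, false)
```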
+""" +function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = 1 + for sym in syms + md = varinfo.metadata[sym] + this_md_iden, this_md_others, offset = get_ranges_and_linked_metadata(md, offset) + all_iden_ranges = merge(all_iden_ranges, this_md_iden) + all_ranges = merge(all_ranges, this_md_others) + end + return all_iden_ranges, all_ranges +end +function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) + all_iden, all_others, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) + return all_iden, all_others +end +function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in md.idcs + is_linked = md.is_transformed[idx] + range = md.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end +function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) + all_iden_ranges = NamedTuple() + all_ranges = Dict{VarName,RangeAndLinked}() + offset = start_offset + for (vn, idx) in vnv.varname_to_index + is_linked = vnv.is_unconstrained[idx] + range = vnv.ranges[idx] .+ (start_offset - 1) + if AbstractPPL.getoptic(vn) === identity + all_iden_ranges = merge( + all_iden_ranges, + NamedTuple((AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked),)), + ) + else + all_ranges[vn] = RangeAndLinked(range, is_linked) + end + offset += length(range) + end + return all_iden_ranges, all_ranges, offset +end diff --git a/src/model.jl b/src/model.jl index 94fcd9fd4..2bcfe8f98 100644 --- a/src/model.jl +++ b/src/model.jl @@ -986,9 +986,13 @@ Return the arguments and keyword arguments to be passed to the evaluator of the ) where {_F,argnames} unwrap_args = [ if is_splat_symbol(var) - :($matchingvalue(varinfo, model.args.$var)...) + :( + $matchingvalue( + $get_param_eltype(varinfo, model.context), model.args.$var + )... + ) else - :($matchingvalue(varinfo, model.args.$var)) + :($matchingvalue($get_param_eltype(varinfo, model.context), model.args.$var)) end for var in argnames ] return quote @@ -1006,6 +1010,30 @@ Return the arguments and keyword arguments to be passed to the evaluator of the end end +""" + get_param_eltype(varinfo::AbstractVarInfo, context::AbstractContext) + +Get the element type of the parameters being used to evaluate a model, using a `varinfo` +under the given `context`. For example, when evaluating a model with ForwardDiff AD, this +should return `ForwardDiff.Dual`. + +By default, this uses `eltype(varinfo)` which is slightly cursed. This relies on the fact +that typically, before evaluation, the parameters will have been inserted into the VarInfo's +metadata field. + +For `InitContext`, it's quite different: because `InitContext` is responsible for supplying +the parameters, we can avoid using `eltype(varinfo)` and instead query the parameters inside +it. See the docstring of `get_param_eltype(strategy::AbstractInitStrategy)` for more +explanation. 
+""" +function get_param_eltype(vi::AbstractVarInfo, ctx::AbstractParentContext) + return get_param_eltype(vi, DynamicPPL.childcontext(ctx)) +end +get_param_eltype(vi::AbstractVarInfo, ::AbstractContext) = eltype(vi) +function get_param_eltype(::AbstractVarInfo, ctx::InitContext) + return get_param_eltype(ctx.strategy) +end + """ getargnames(model::Model) diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl new file mode 100644 index 000000000..6aff9d04d --- /dev/null +++ b/src/onlyaccs.jl @@ -0,0 +1,31 @@ +""" + OnlyAccsVarInfo + +This is a wrapper around an `AccumulatorTuple` that implements the minimal `AbstractVarInfo` +interface to work with the `tilde_assume!!` and `tilde_observe!!` functions for +`InitContext`. + +Note that this does not implement almost every other AbstractVarInfo interface function, and +so using attempting to use this with a different leaf context such as `DefaultContext` will +result in errors. + +Conceptually, one can also think of this as a VarInfo that doesn't contain a metadata field. +This is also why it only works with `InitContext`: in this case, the parameters used for +evaluation are supplied by the context instead of the metadata. +""" +struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo + accs::Accs +end +OnlyAccsVarInfo() = OnlyAccsVarInfo(default_accumulators()) +function OnlyAccsVarInfo(accs::NTuple{N,AbstractAccumulator}) where {N} + return OnlyAccsVarInfo(AccumulatorTuple(accs)) +end + +# Minimal AbstractVarInfo interface +@inline DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi +@inline DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs +@inline DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = + OnlyAccsVarInfo(accs) +@inline Base.haskey(::OnlyAccsVarInfo, ::VarName) = false +@inline DynamicPPL.is_transformed(::OnlyAccsVarInfo) = false +@inline BangBang.push!!(vi::OnlyAccsVarInfo, vn, y, dist) = vi diff --git a/src/utils.jl b/src/utils.jl index b55a2f715..75fb805dc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -15,6 +15,41 @@ This is Float64 on 64-bit systems and Float32 on 32-bit systems. """ const LogProbType = float(Real) +""" + typed_identity(x) + +Identity function, but with an overload for `with_logabsdet_jacobian` to ensure +that it returns a sensible zero logjac. + +The problem with plain old `identity` is that the default definition of +`with_logabsdet_jacobian` for `identity` returns `zero(eltype(x))`: +https://github.com/JuliaMath/ChangesOfVariables.jl/blob/d6a8115fc9b9419decbdb48e2c56ec9675b4c6a4/src/with_ladj.jl#L154 + +This is fine for most samples `x`, but if `eltype(x)` doesn't return a sensible type (e.g. +if it's `Any`), then using `identity` will error with `zero(Any)`. This can happen with, +for example, `ProductNamedTupleDistribution`: + +```julia +julia> using Distributions; d = product_distribution((a = Normal(), b = LKJCholesky(3, 0.5))); + +julia> eltype(rand(d)) +Any +``` + +The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to +support distributions with non-numeric samples. + +Furthermore, in principle, the type of the log-probability should be separate from the type +of the sample. Thus, instead of using `zero(LogProbType)`, we should use the eltype of the +LogJacobianAccumulator. There's no easy way to thread that through here, but if a way to do +this is discovered, then `typed_identity` is what will allow us to obtain that custom +behaviour. 
+""" +function typed_identity end +@inline typed_identity(x) = x +@inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = + (x, zero(LogProbType)) + """ @addlogprob!(ex) diff --git a/test/Project.toml b/test/Project.toml index 2dbd5b455..efd916308 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" diff --git a/test/fasteval.jl b/test/fasteval.jl new file mode 100644 index 000000000..f1c535643 --- /dev/null +++ b/test/fasteval.jl @@ -0,0 +1,232 @@ +module DynamicPPLFastLDFTests + +using AbstractPPL: AbstractPPL +using Chairmarks +using DynamicPPL +using Distributions +using DistributionsAD: filldist +using ADTypes +using DynamicPPL.Experimental: FastLDF +using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest +using LinearAlgebra: I +using Test +using LogDensityProblems: LogDensityProblems + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +# Need to include this block here in case we run this test file standalone +@static if VERSION < v"1.12" + using Pkg + Pkg.add("Mooncake") + using Mooncake: Mooncake +end + +@testset "FastLDF: Correctness" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$varinfo_func" for varinfo_func in [ + DynamicPPL.untyped_varinfo, + DynamicPPL.typed_varinfo, + DynamicPPL.untyped_vector_varinfo, + DynamicPPL.typed_vector_varinfo, + ] + unlinked_vi = varinfo_func(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + nt_ranges, dict_ranges = DynamicPPL.Experimental.get_ranges_and_linked(vi) + params = map(identity, vi[:]) + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = if AbstractPPL.getoptic(vn) === identity + nt_ranges[AbstractPPL.getsym(vn)] + else + dict_ranges[vn] + end + @test params[range_with_linked.range] == + DynamicPPL.getindex_internal(vi, vn) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked + end + + # Compare results of FastLDF vs ordinary LogDensityFunction. These tests + # can eventually go once we replace LogDensityFunction with FastLDF, but + # for now it helps to have this check! (Eventually we should just check + # against manually computed log-densities). + # + # TODO(penelopeysm): I think we need to add tests for some really + # pathological models here. 
+                @testset "$getlogdensity" for getlogdensity in (
+                    DynamicPPL.getlogjoint_internal,
+                    DynamicPPL.getlogjoint,
+                    DynamicPPL.getloglikelihood,
+                    DynamicPPL.getlogprior_internal,
+                    DynamicPPL.getlogprior,
+                )
+                    ldf = DynamicPPL.LogDensityFunction(m, getlogdensity, vi)
+                    fldf = FastLDF(m, getlogdensity, vi)
+                    @test LogDensityProblems.logdensity(ldf, params) ≈
+                        LogDensityProblems.logdensity(fldf, params)
+                end
+            end
+        end
+    end
+
+    @testset "Threaded observe" begin
+        if Threads.nthreads() > 1
+            @model function threaded(y)
+                x ~ Normal()
+                Threads.@threads for i in eachindex(y)
+                    y[i] ~ Normal(x)
+                end
+            end
+            N = 100
+            model = threaded(zeros(N))
+            ldf = DynamicPPL.Experimental.FastLDF(model)
+
+            xs = [1.0]
+            @test LogDensityProblems.logdensity(ldf, xs) ≈
+                logpdf(Normal(), xs[1]) + N * logpdf(Normal(xs[1]), 0.0)
+        end
+    end
+end
+
+@testset "FastLDF: performance" begin
+    if Threads.nthreads() == 1
+        # Evaluating these three models should not lead to any allocations (but only when
+        # not using TSVI).
+        @model function f()
+            x ~ Normal()
+            return 1.0 ~ Normal(x)
+        end
+        @model function submodel_inner()
+            m ~ Normal(0, 1)
+            s ~ Exponential()
+            return (m=m, s=s)
+        end
+        # Note that for the allocation tests to work on this one, `inner` has
+        # to be passed as an argument to `submodel_outer`, instead of just
+        # being called inside the model function itself
+        @model function submodel_outer(inner)
+            params ~ to_submodel(inner)
+            y ~ Normal(params.m, params.s)
+            return 1.0 ~ Normal(y)
+        end
+        @testset for model in (f(), submodel_inner(), submodel_outer(submodel_inner()))
+            vi = VarInfo(model)
+            fldf = DynamicPPL.Experimental.FastLDF(
+                model, DynamicPPL.getlogjoint_internal, vi
+            )
+            x = vi[:]
+            bench = median(@be LogDensityProblems.logdensity(fldf, x))
+            @test iszero(bench.allocs)
+        end
+    end
+end
+
+@testset "AD with FastLDF" begin
+    # Used as the ground truth that others are compared against.
+    ref_adtype = AutoForwardDiff()
+
+    test_adtypes = @static if VERSION < v"1.12"
+        [
+            AutoReverseDiff(; compile=false),
+            AutoReverseDiff(; compile=true),
+            AutoMooncake(; config=nothing),
+        ]
+    else
+        [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)]
+    end
+
+    @testset "Correctness" begin
+        @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
+            varinfo = VarInfo(m)
+            linked_varinfo = DynamicPPL.link(varinfo, m)
+            f = FastLDF(m, getlogjoint_internal, linked_varinfo)
+            x = map(identity, linked_varinfo[:])
+
+            # Calculate reference logp + gradient of logp using ForwardDiff
+            ref_ad_result = run_ad(m, ref_adtype; varinfo=linked_varinfo, test=NoTest())
+            ref_logp, ref_grad = ref_ad_result.value_actual, ref_ad_result.grad_actual
+
+            @testset "$adtype" for adtype in test_adtypes
+                @info "Testing AD on: $(m.f) - $adtype"
+
+                @test run_ad(
+                    m,
+                    adtype;
+                    varinfo=linked_varinfo,
+                    test=WithExpectedResult(ref_logp, ref_grad),
+                ) isa Any
+            end
+        end
+    end
+
+    # Test that various different ways of specifying array types as arguments work with all
+    # ADTypes.
+    @testset "Array argument types" begin
+        test_m = randn(2, 3)
+
+        function eval_logp_and_grad(model, m, adtype)
+            ldf = FastLDF(model(); adtype=adtype)
+            return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
+        end
+
+        @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
+            m = Matrix{T}(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        scalar_matrix_model_reference = eval_logp_and_grad(
+            scalar_matrix_model, test_m, ref_adtype
+        )
+
+        @model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
+            m = T(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
+
+        @model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
+            m = Array{T}(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        scalar_array_model_reference = eval_logp_and_grad(
+            scalar_array_model, test_m, ref_adtype
+        )
+
+        @model function array_model(::Type{T}=Array{Float64}) where {T}
+            m = T(undef, 2, 3)
+            return m ~ filldist(MvNormal(zeros(2), I), 3)
+        end
+
+        array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype)
+
+        @testset "$adtype" for adtype in test_adtypes
+            scalar_matrix_model_logp_and_grad = eval_logp_and_grad(
+                scalar_matrix_model, test_m, adtype
+            )
+            @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1]
+            @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2]
+            matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype)
+            @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1]
+            @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2]
+            scalar_array_model_logp_and_grad = eval_logp_and_grad(
+                scalar_array_model, test_m, adtype
+            )
+            @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1]
+            @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2]
+            array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype)
+            @test array_model_logp_and_grad[1] ≈ array_model_reference[1]
+            @test array_model_logp_and_grad[2] ≈ array_model_reference[2]
+        end
+    end
+end
+
+end
diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl
index b40bbeb8f..ea4ec497d 100644
--- a/test/integration/enzyme/main.jl
+++ b/test/integration/enzyme/main.jl
@@ -6,8 +6,10 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const
 using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test
 
 ADTYPES = Dict(
-    "EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward)),
-    "EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse)),
+    "EnzymeForward" =>
+        AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const),
+    "EnzymeReverse" =>
+        AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const),
 )
 
 @testset "$ad_key" for (ad_key, ad_type) in ADTYPES
diff --git a/test/runtests.jl b/test/runtests.jl
index 861d3bb87..10fac8b0f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -89,6 +89,7 @@ include("test_util.jl")
         include("ext/DynamicPPLMooncakeExt.jl")
     end
     include("ad.jl")
+    include("fasteval.jl")
 end
 
 @testset "prob and logprob macro" begin
     @test_throws ErrorException prob"..."