Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7cddac7
Fast Log Density Function
penelopeysm Nov 5, 2025
5ed4295
Make it work with AD
penelopeysm Nov 6, 2025
e199520
Optimise performance for identity VarNames
penelopeysm Nov 6, 2025
4cefaca
Mark `get_range_and_linked` as having zero derivative
penelopeysm Nov 6, 2025
6dfd106
Update comment
penelopeysm Nov 6, 2025
4ca9cf7
Squeeze down VarInfo allocations
mhauru Nov 6, 2025
7c6e8c1
Remove old out-of-date comment
mhauru Nov 6, 2025
5c817a4
implement `is_transformed(::VarNamedVector)`
penelopeysm Nov 6, 2025
93daa2b
Handle errors in benchmark suite
penelopeysm Nov 6, 2025
41ee7f3
make AD testing / benchmarking use FastLDF
penelopeysm Nov 6, 2025
22e32a6
Fix tests
penelopeysm Nov 6, 2025
79cc128
Optimise away `make_evaluate_args_and_kwargs`
penelopeysm Nov 6, 2025
f7c6a78
const func annotation
penelopeysm Nov 6, 2025
b1a7650
Disable benchmarks on non-typed-Metadata-VarInfo
penelopeysm Nov 6, 2025
e60873a
Fix `_evaluate!!` correctly to handle submodels
penelopeysm Nov 6, 2025
fa0664e
Actually fix submodel evaluate
penelopeysm Nov 6, 2025
09a1fbb
Document thoroughly and organise code
penelopeysm Nov 6, 2025
7306ba4
Support more VarInfos, make it thread-safe (?)
penelopeysm Nov 6, 2025
53bccc1
fix bug in parsing ranges from metadata/VNV
penelopeysm Nov 6, 2025
30b9247
Fix get_param_eltype for TSVI
penelopeysm Nov 6, 2025
316937a
Disable Enzyme benchmark
penelopeysm Nov 6, 2025
7fafc86
Merge branch 'mhauru/no-allocs-allowed' into py/fastldf2
penelopeysm Nov 6, 2025
9c71e81
Revert "Handle errors in benchmark suite"
penelopeysm Nov 6, 2025
075cee8
Don't override _evaluate!!, that breaks ForwardDiff (sometimes)
penelopeysm Nov 6, 2025
c44bae1
Merge branch 'py/fastldf' into py/fastldf2
penelopeysm Nov 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ chosen_combinations = [
false,
),
("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false),
("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
# ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true),
# ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true),
# ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true),
# ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true),
# ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true),
("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true),
("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true),
("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
# ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true),
("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true),
("Multivariate 1k", multivariate1k, :typed, :mooncake, true),
("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true),
Expand Down
4 changes: 1 addition & 3 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(
model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend
)
f = DynamicPPL.FastLDF(model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend)
# The parameters at which we evaluate f.
θ = vi[:]

Expand Down
13 changes: 6 additions & 7 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
module DynamicPPLEnzymeCoreExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using EnzymeCore
else
using ..DynamicPPL: DynamicPPL
using ..EnzymeCore
end
using DynamicPPL: DynamicPPL
using EnzymeCore

# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) =
nothing
# Likewise for get_range_and_linked.
@inline EnzymeCore.EnzymeRules.inactive(
::typeof(DynamicPPL.get_range_and_linked), args...
) = nothing

end
3 changes: 2 additions & 1 deletion ext/DynamicPPLMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module DynamicPPLMooncakeExt

using DynamicPPL: DynamicPPL, is_transformed
using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked
using Mooncake: Mooncake

# This is purely an optimisation.
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg}

end # module
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ include("simple_varinfo.jl")
include("compiler.jl")
include("pointwise_logdensities.jl")
include("logdensityfunction.jl")
include("fastldf.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
Expand Down
38 changes: 20 additions & 18 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,14 +718,15 @@ end
# TODO(mhauru) matchingvalue has methods that can accept both types and values. Why?
# TODO(mhauru) This function needs a more comprehensive docstring.
"""
matchingvalue(vi, value)
matchingvalue(param_eltype, value)
Convert the `value` to the correct type for the `vi` object.
Convert the `value` to the correct type, given the element type of the parameters
being used to evaluate the model.
"""
function matchingvalue(vi, value)
function matchingvalue(param_eltype, value)
T = typeof(value)
if hasmissing(T)
_value = convert(get_matching_type(vi, T), value)
_value = convert(get_matching_type(param_eltype, T), value)
# TODO(mhauru) Why do we make a deepcopy, even though in the !hasmissing branch we
# are happy to return `value` as-is?
if _value === value
Expand All @@ -738,29 +739,30 @@ function matchingvalue(vi, value)
end
end

function matchingvalue(vi, value::FloatOrArrayType)
return get_matching_type(vi, value)
function matchingvalue(param_eltype, value::FloatOrArrayType)
return get_matching_type(param_eltype, value)
end
function matchingvalue(vi, ::TypeWrap{T}) where {T}
return TypeWrap{get_matching_type(vi, T)}()
function matchingvalue(param_eltype, ::TypeWrap{T}) where {T}
return TypeWrap{get_matching_type(param_eltype, T)}()
end

# TODO(mhauru) This function needs a more comprehensive docstring. What is it for?
"""
get_matching_type(vi, ::TypeWrap{T}) where {T}
get_matching_type(param_eltype, ::TypeWrap{T}) where {T}
Get the specialized version of type `T` for `vi`.
Get the specialized version of type `T`, given an element type of the parameters
being used to evaluate the model.
"""
get_matching_type(_, ::Type{T}) where {T} = T
function get_matching_type(vi, ::Type{<:Union{Missing,AbstractFloat}})
return Union{Missing,float_type_with_fallback(eltype(vi))}
function get_matching_type(param_eltype, ::Type{<:Union{Missing,AbstractFloat}})
return Union{Missing,float_type_with_fallback(param_eltype)}
end
function get_matching_type(vi, ::Type{<:AbstractFloat})
return float_type_with_fallback(eltype(vi))
function get_matching_type(param_eltype, ::Type{<:AbstractFloat})
return float_type_with_fallback(param_eltype)
end
function get_matching_type(vi, ::Type{<:Array{T,N}}) where {T,N}
return Array{get_matching_type(vi, T),N}
function get_matching_type(param_eltype, ::Type{<:Array{T,N}}) where {T,N}
return Array{get_matching_type(param_eltype, T),N}
end
function get_matching_type(vi, ::Type{<:Array{T}}) where {T}
return Array{get_matching_type(vi, T)}
function get_matching_type(param_eltype, ::Type{<:Array{T}}) where {T}
return Array{get_matching_type(param_eltype, T)}
end
Loading
Loading