Run this extract using #1113 or any subsequent PR. The correct gradient is -51.0 (both FiniteDifferences and ForwardDiff get this right). However, sometimes Enzyme misses out on some terms and reports -50 or -48. This happens with both slow LDF and fast LDF, so it's unrelated to #1113.
(Make sure to launch Julia with multiple threads, of course)
```julia
using DynamicPPL, Distributions, LogDensityProblems, Chairmarks, LinearAlgebra,
      ADTypes, ForwardDiff, Enzyme, FiniteDifferences

if Threads.nthreads() == 1
    error("run this code with multiple threads")
end

const adtypes = (
    AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
    AutoForwardDiff(),
    AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const),
)

function check_ldfs(model)
    vi = VarInfo(model)
    xs = [1.0]
    for adtype in adtypes
        sldf = DynamicPPL.LogDensityFunction(model, getlogjoint, vi; adtype=adtype)
        fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi; adtype=adtype)
        sldf_grad = LogDensityProblems.logdensity_and_gradient(sldf, xs)
        fldf_grad = LogDensityProblems.logdensity_and_gradient(fldf, xs)
        @show adtype
        @show sldf_grad
        @show fldf_grad
        @assert sldf_grad[2] ≈ fldf_grad[2]
    end
end

@model function threads(y=zeros(50))
    x ~ Normal()
    Threads.@threads for i in eachindex(y)
        y[i] ~ Normal(x)
    end
end

check_ldfs(threads())
```
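For reference, the expected gradient of -51.0 can be checked by hand. Assuming the log-joint is `logpdf(Normal(0,1), x) + sum(logpdf(Normal(x,1), y[i]))` with `y = zeros(50)` and unit variance (as in the model above, evaluated at `x = 1.0`), its derivative in `x` is `-x + sum(y[i] - x)`. A minimal sketch of that arithmetic:

```python
# Analytic check of the expected gradient (-51.0) for the reproducer's model.
# Assumes: x ~ Normal(0, 1), y[i] ~ Normal(x, 1) for 50 observations y[i] = 0,
# evaluated at x = 1.0.
x = 1.0
y = [0.0] * 50

# d/dx logpdf(Normal(0,1), x)    = -x
# d/dx logpdf(Normal(x,1), y_i)  = y_i - x
grad = -x + sum(yi - x for yi in y)
print(grad)  # -51.0
```

The reported -50 and -48 thus look like one or three of the 50 per-observation `y[i] - x = -1` terms being dropped.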