
Conversation

@penelopeysm commented Nov 11, 2025

This reimplements FastLDF, conceptually in the same way as #1113; please see that PR for the bulk of the explanation. The difference is that this PR also unifies the implementations of FastLDF and InitFromParams, such that FastLDF is now simply InitFromParams backed by the combination of a flat parameter vector and precomputed ranges.

Here's a slightly modified diagram from my slides yesterday:

[Diagram showing how InitFromParams and FastLDF are related]
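
As a rough illustration of the "vector + ranges" backing (a hypothetical sketch with made-up names, not the actual implementation): all parameter values live in one flat vector, and each VarName is mapped to the range of indices it occupies, so retrieving a variable is just a slice.

using DynamicPPL

# One flat vector holding the values of every variable in the model...
params = [0.5, 1.2, -0.3, 0.8]
# ...plus a precomputed map from each VarName to its index range.
ranges = Dict(@varname(mu) => 1:1, @varname(theta) => 2:4)
# Looking a variable up is then a cheap slice into the vector.
lookup(vn) = view(params, ranges[vn])
lookup(@varname(theta))  # 3-element view: [1.2, -0.3, 0.8]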

Other speedups

Note that this unification also means that the other initialisation strategies, i.e. InitFromPrior, InitFromUniform, and other forms of InitFromParams, benefit from the speedup as well (as shown in the top half of the diagram above). That work was essentially done in #1125 but is lumped into this PR too; see that PR for benchmarks.
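
For concreteness, here is a minimal sketch of constructing these strategies (the Dict-of-VarNames form for InitFromParams reflects my reading of the API; see the docs preview linked below for the authoritative signatures):

using DynamicPPL, Distributions

@model function demo()
    x ~ Normal()
    y ~ Normal(x)
end

# Draw initial values from the prior...
InitFromPrior()
# ...or uniformly in unconstrained space...
InitFromUniform()
# ...or supply explicit values for the variables.
InitFromParams(Dict(@varname(x) => 0.1, @varname(y) => -0.2))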

Does this still need to be Experimental?

For this PR, I'd suggest yes, if only to prove correctness compared to the old LDF. Making this replace the old LDF should be a fairly trivial follow-up. I'm open to other ideas.

Does this need to be breaking?

Yes, because the expected return value of DynamicPPL.init has changed. Technically, init wasn't exported, but AbstractInitStrategy was, so init was effectively public (it should have been exported).

On top of that, this PR relies on changes in #1133, which are also breaking.

Benchmarks

Performance characteristics are exactly the same as in the original PR #1113. Benchmarks were run on Julia 1.11.7 with 1 thread.

Benchmarking code
using DynamicPPL, Distributions, LogDensityProblems, Chairmarks, LinearAlgebra
using ADTypes, ForwardDiff, ReverseDiff
@static if VERSION < v"1.12"
    using Enzyme, Mooncake
end

const adtypes = @static if VERSION < v"1.12"
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
        ("MC", AutoMooncake()),
        ("EN" => AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const))
    ]
else
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
    ]
end

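# Compare evaluation and gradient timings of the old LogDensityFunction against
# FastLDF for the given model; AD backends whose type matches `skip` are skipped.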
function benchmark_ldfs(model; skip=Union{})
    vi = VarInfo(model)
    x = vi[:]
    ldf_no = DynamicPPL.LogDensityFunction(model, getlogjoint, vi)
    fldf_no = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi)
    @assert LogDensityProblems.logdensity(ldf_no, x) ≈ LogDensityProblems.logdensity(fldf_no, x)
    median_old = median(@be LogDensityProblems.logdensity(ldf_no, x))
    print("LogDensityFunction: eval      ----  ")
    display(median_old)
    median_new = median(@be LogDensityProblems.logdensity(fldf_no, x))
    print("           FastLDF: eval      ----  ")
    display(median_new)
    println("                  speedup     ----  ", median_old.time / median_new.time)
    for (name, adtype) in adtypes
        adtype isa skip && continue
        ldf = DynamicPPL.LogDensityFunction(model, getlogjoint, vi; adtype=adtype)
        fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi; adtype=adtype)
        ldf_grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
        fldf_grad = LogDensityProblems.logdensity_and_gradient(fldf, x)
        @assert ldf_grad[2] ≈ fldf_grad[2]
        median_old = median(@be LogDensityProblems.logdensity_and_gradient(ldf, x))
        print("LogDensityFunction: grad ($name) ----  ")
        display(median_old)
        median_new = median(@be LogDensityProblems.logdensity_and_gradient(fldf, x))
        print("           FastLDF: grad ($name) ----  ")
        display(median_new)
        println("                 speedup ($name) ----  ", median_old.time / median_new.time)
    end
end

@model f() = x ~ Normal()
benchmark_ldfs(f())

y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    for i in eachindex(y)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
benchmark_ldfs(eight_schools(y, sigma))

@model function badvarnames()
    N = 20
    x = Vector{Float64}(undef, N)
    for i in 1:N
        x[i] ~ Normal()
    end
end
benchmark_ldfs(badvarnames())

@model function inner()
    m ~ Normal(0, 1)
    s ~ Exponential()
    return (m=m, s=s)
end
@model function withsubmodel()
    params ~ to_submodel(inner())
    y ~ Normal(params.m, params.s)
    1.0 ~ Normal(y)
end
benchmark_ldfs(withsubmodel())

Trivial model

julia> benchmark_ldfs(f())
LogDensityFunction: eval      ----  170.964 ns (6 allocs: 192 bytes)
           FastLDF: eval      ----  10.944 ns
                  speedup     ----  15.621700990952178
LogDensityFunction: grad (FD) ----  317.935 ns (13 allocs: 496 bytes)
           FastLDF: grad (FD) ----  54.127 ns (3 allocs: 96 bytes)
                 speedup (FD) ----  5.873855538906422
LogDensityFunction: grad (RD) ----  4.250 μs (82 allocs: 3.062 KiB)
           FastLDF: grad (RD) ----  3.019 μs (46 allocs: 1.562 KiB)
                 speedup (RD) ----  1.4079581845621525
LogDensityFunction: grad (MC) ----  1.100 μs (25 allocs: 1.219 KiB)
           FastLDF: grad (MC) ----  338.481 ns (4 allocs: 192 bytes)
                 speedup (MC) ----  3.250793303424883
LogDensityFunction: grad (EN) ----  432.455 ns (16 allocs: 560 bytes)
           FastLDF: grad (EN) ----  128.409 ns (2 allocs: 64 bytes)
                 speedup (EN) ----  3.3677876106194686

Eight-schools centred

LogDensityFunction: eval      ----  877.594 ns (21 allocs: 1.344 KiB)
           FastLDF: eval      ----  209.184 ns (4 allocs: 256 bytes)
                  speedup     ----  4.195326219512196
LogDensityFunction: grad (FD) ----  1.611 μs (28 allocs: 5.484 KiB)
           FastLDF: grad (FD) ----  672.465 ns (11 allocs: 2.594 KiB)
                 speedup (FD) ----  2.3956633005948262
LogDensityFunction: grad (RD) ----  40.209 μs (614 allocs: 25.562 KiB)
           FastLDF: grad (RD) ----  38.708 μs (562 allocs: 20.562 KiB)
                 speedup (RD) ----  1.03877751369226
LogDensityFunction: grad (MC) ----  4.528 μs (64 allocs: 4.016 KiB)
           FastLDF: grad (MC) ----  1.183 μs (12 allocs: 784 bytes)
                 speedup (MC) ----  3.8262402956653028
LogDensityFunction: grad (EN) ----  1.858 μs (44 allocs: 2.609 KiB)
           FastLDF: grad (EN) ----  739.026 ns (13 allocs: 832 bytes)
                 speedup (EN) ----  2.5145699058742537

Lots of IndexLenses

LogDensityFunction: eval      ----  1.448 μs (46 allocs: 1.906 KiB)
           FastLDF: eval      ----  459.641 ns (2 allocs: 224 bytes)
                  speedup     ----  3.150069687595608
LogDensityFunction: grad (FD) ----  4.535 μs (103 allocs: 14.266 KiB)
           FastLDF: grad (FD) ----  2.697 μs (11 allocs: 4.281 KiB)
                 speedup (FD) ----  1.6813743665801506
LogDensityFunction: grad (RD) ----  59.584 μs (1076 allocs: 38.828 KiB)
           FastLDF: grad (RD) ----  51.209 μs (773 allocs: 27.438 KiB)
                 speedup (RD) ----  1.1635454705227597
LogDensityFunction: grad (MC) ----  6.656 μs (160 allocs: 7.000 KiB)
           FastLDF: grad (MC) ----  2.229 μs (28 allocs: 1.094 KiB)
                 speedup (MC) ----  2.985981308411215
LogDensityFunction: grad (EN) ----  3.271 μs (64 allocs: 6.141 KiB)
           FastLDF: grad (EN) ----  1.608 μs (5 allocs: 2.188 KiB)
                 speedup (EN) ----  2.0343495042622473

Submodel

LogDensityFunction: eval      ----  867.424 ns (20 allocs: 1.234 KiB)
           FastLDF: eval      ----  103.168 ns
                  speedup     ----  8.407896391298882
LogDensityFunction: grad (FD) ----  1.175 μs (27 allocs: 2.219 KiB)
           FastLDF: grad (FD) ----  187.776 ns (3 allocs: 112 bytes)
                 speedup (FD) ----  6.25744516852358
LogDensityFunction: grad (RD) ----  13.959 μs (221 allocs: 9.266 KiB)
           FastLDF: grad (RD) ----  10.896 μs (148 allocs: 5.188 KiB)
                 speedup (RD) ----  1.2811252351888394
LogDensityFunction: grad (MC) ----  5.750 μs (72 allocs: 3.312 KiB)
           FastLDF: grad (MC) ----  599.667 ns (6 allocs: 240 bytes)
                 speedup (MC) ----  9.588660366870483
LogDensityFunction: grad (EN) ----  2.432 μs (52 allocs: 2.500 KiB)
           FastLDF: grad (EN) ----  341.659 ns (2 allocs: 80 bytes)
                 speedup (EN) ----  7.117680019783942

MCMC

using Turing, Random, LinearAlgebra
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(sigma)), tau^2 * I)
    for i in eachindex(sigma)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
model = eight_schools(y, sigma);

using Enzyme, ADTypes
adtype = AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const)
@time sample(model, NUTS(; adtype=adtype), 1000; nadapts=10000, thinning=10, progress=false, verbose=false);

With this PR, that sampling call goes from around 8.8 seconds down to 1.7 seconds.

@github-actions bot commented Nov 11, 2025

Benchmark Report for Commit 62a8746

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────┬─────────────────┐
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │ t(eval)/t(ref) │ t(grad)/t(eval) │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼────────────────┼─────────────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │            6.6 │             1.8 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │          751.9 │            42.6 │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │          422.4 │            53.9 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │          804.5 │            38.0 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │         7031.5 │            27.7 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │          761.0 │            43.2 │
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │          813.6 │            37.3 │
│           Smorgasbord │   201 │ reversediff │             typed │   true │          930.5 │            45.8 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │          769.7 │             6.1 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │          918.5 │             4.2 │
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │         4075.2 │             5.8 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │         1052.8 │             8.8 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │        43296.1 │             5.4 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │         8927.8 │            10.4 │
│               Dynamic │    10 │    mooncake │             typed │   true │          126.4 │            11.7 │
│              Submodel │     1 │    mooncake │             typed │   true │            8.7 │             6.4 │
│                   LDA │    12 │ reversediff │             typed │   true │         1028.8 │             2.0 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴────────────────┴─────────────────┘

@github-actions bot commented:
DynamicPPL.jl documentation for PR #1132 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1132/

@penelopeysm force-pushed the py/ldf branch 2 times, most recently from 248f374 to a4c71e6 on November 11, 2025 at 12:33
@penelopeysm changed the base branch from main to breaking on November 11, 2025 at 12:33
@codecov bot commented Nov 11, 2025

Codecov Report

❌ Patch coverage is 94.20290% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.67%. Comparing base (4ca9528) to head (62a8746).

Files with missing lines          Patch %   Lines
src/onlyaccs.jl                    70.00%   3 Missing ⚠️
src/compiler.jl                    85.71%   2 Missing ⚠️
ext/DynamicPPLEnzymeCoreExt.jl      0.00%   1 Missing ⚠️
src/fasteval.jl                    98.61%   1 Missing ⚠️
src/utils.jl                       66.66%   1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1132      +/-   ##
============================================
+ Coverage     81.32%   81.67%   +0.35%     
============================================
  Files            40       42       +2     
  Lines          3807     3919     +112     
============================================
+ Hits           3096     3201     +105     
- Misses          711      718       +7     

☔ View full report in Codecov by Sentry.
