using .ReverseDiff: compile, GradientTape
using .ReverseDiff.DiffResults: GradientResult

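# AD backend for ReverseDiff. The `cache` type parameter selects between
# fresh tapes (`false`) and compiled, memoized tapes (`true`); the global
# `RDCache` flag records the current choice.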
struct ReverseDiffAD{cache} <: ADBackend end
const RDCache = Ref(false)
setrdcache(b::Bool) = setrdcache(Val(b))
setrdcache(::Val{false}) = RDCache[] = false
setrdcache(::Val) = error("Memoization.jl is not loaded. Please load it before setting the cache to true.")
function emptyrdcache end

getrdcache() = RDCache[]
ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()}
function setadbackend(::Val{:reversediff})
    ADBACKEND[] = :reversediff
end

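# A minimal usage sketch, assuming the `setadbackend(::Symbol)` wrapper that
# dispatches to the `Val` method above is defined elsewhere in the package:
#
#     setadbackend(:reversediff)  # select ReverseDiff as the AD backend
#     using Memoization
#     setrdcache(true)            # opt in to compiled, memoized tapes
#     # ... run inference ...
#     emptyrdcache()              # drop all cached tapes
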
# Compute the log density and its gradient at θ, rebuilding the ReverseDiff
# tape on every call (the uncached path).
function gradient_logp(
    backend::ReverseDiffAD{false},
    θ::AbstractVector{<:Real},
    vi::VarInfo,
    model::Model,
    sampler::AbstractSampler = SampleFromPrior(),
)
    # Specify the objective: evaluate the model on a copy of `vi` at the new
    # parameter values and return the accumulated log probability.
    function f(θ)
        new_vi = VarInfo(vi, sampler, θ)
        model(new_vi, sampler)
        return getlogp(new_vi)
    end
    tp, result = taperesult(f, θ)
    ReverseDiff.gradient!(result, tp, θ)
    l = DiffResults.value(result)
    ∂l∂θ::typeof(θ) = DiffResults.gradient(result)

    return l, ∂l∂θ
end

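# Helpers for the uncached path: build a fresh tape and a matching
# DiffResults buffer for each call.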
tape(f, x) = GradientTape(f, x)
function taperesult(f, x)
    return tape(f, x), GradientResult(x)
end

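# Everything below is only defined once Memoization.jl is loaded. It enables
# `setrdcache(true)` and caches one compiled tape per model and sampler.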
@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
    setrdcache(::Val{true}) = RDCache[] = true
    function emptyrdcache()
        # Drop every Memoization cache entry that belongs to `memoized_taperesult`.
        for k in keys(Memoization.caches)
            if k[1] === typeof(memoized_taperesult)
                pop!(Memoization.caches, k)
            end
        end
    end
    # Cached variant: reuse a compiled tape memoized on the model and sampler.
    function gradient_logp(
        backend::ReverseDiffAD{true},
        θ::AbstractVector{<:Real},
        vi::VarInfo,
        model::Model,
        sampler::AbstractSampler = SampleFromPrior(),
    )
        # Specify the objective: same closure as in the uncached method.
        function f(θ)
            new_vi = VarInfo(vi, sampler, θ)
            model(new_vi, sampler)
            return getlogp(new_vi)
        end
        ctp, result = memoized_taperesult(f, θ)
        ReverseDiff.gradient!(result, ctp, θ)
        l = DiffResults.value(result)
        ∂l∂θ = DiffResults.gradient(result)

        return l, ∂l∂θ
    end

    # This makes sure we generate a single tape per Turing model and sampler:
    # `Memoization._get!` is overloaded so the cache is keyed on the closure's
    # type and the input's type and size, not on object identity.
    struct RDTapeKey{F, Tx}
        f::F
        x::Tx
    end
    function Memoization._get!(f::Union{Function, Type}, d::IdDict, keys::Tuple{Tuple{RDTapeKey}, Any})
        key = keys[1][1]
        return Memoization._get!(f, d, (typeof(key.f), typeof(key.x), size(key.x)))
    end
    memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
    Memoization.@memoize function memoized_taperesult(k::RDTapeKey)
        return compiledtape(k.f, k.x), GradientResult(k.x)
    end
    memoized_tape(f, x) = memoized_tape(RDTapeKey(f, x))
    Memoization.@memoize memoized_tape(k::RDTapeKey) = compiledtape(k.f, k.x)
    compiledtape(f, x) = compile(GradientTape(f, x))
end
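
# A self-contained sketch of the tape pattern used above, with plain
# ReverseDiff and no Turing types (`f` and `x` here are illustrative):
#
#     using ReverseDiff, DiffResults
#
#     f(x) = sum(abs2, x)
#     x = rand(3)
#     tp = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))
#     result = DiffResults.GradientResult(x)
#     ReverseDiff.gradient!(result, tp, x)   # reuses the compiled tape
#     DiffResults.value(result), DiffResults.gradient(result)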