@@ -12,21 +12,35 @@ struct RangeAndLinked
1212 is_linked:: Bool
1313end
1414
"""
    FastLDFContext{NT<:NamedTuple,V<:AbstractVector{<:Real}} <: AbstractContext

Context that carries a flat parameter vector together with the lookup tables
needed to locate each variable's slice of that vector.

# Fields
- `iden_varname_ranges`: `NamedTuple` keyed by variable symbol, used for
  `VarName`s whose optic is `identity`.
- `varname_ranges`: `Dict` mapping every other `VarName` to its
  `RangeAndLinked` entry.
- `params`: the full parameter vector that the stored ranges index into.
"""
struct FastLDFContext{NT<:NamedTuple,V<:AbstractVector{<:Real}} <: AbstractContext
    # Identity-optic VarNames get their own NamedTuple purely for speed. The
    # original author's measurements: roughly 1.5x for plain evaluation and
    # around 4x when differentiating with Mooncake (model-dependent figures).
    iden_varname_ranges::NT
    # Fallback table holding the ranges of all remaining VarNames.
    varname_ranges::Dict{VarName,RangeAndLinked}
    # Flat vector of parameter values; variable values are read out of it by range.
    params::V
end
# `FastLDFContext` is declared a leaf context: its node trait is `IsLeaf()`.
function DynamicPPL.NodeTrait(::FastLDFContext)
    return IsLeaf()
end
"""
    get_range_and_linked(ctx::FastLDFContext, vn::VarName)

Return the entry stored in `ctx` for the variable `vn`.

When the optic of `vn` is `identity`, the entry is fetched from
`ctx.iden_varname_ranges` using the variable's symbol, which is taken from the
type of `vn`. Any other `VarName` is looked up in the `ctx.varname_ranges`
dictionary, with `vn` itself as the key.
"""
function get_range_and_linked(
    ctx::FastLDFContext, ::VarName{name,typeof(identity)}
) where {name}
    iden_ranges = ctx.iden_varname_ranges
    return iden_ranges[name]
end
get_range_and_linked(ctx::FastLDFContext, vn::VarName) = getindex(ctx.varname_ranges, vn)
2035
2136function tilde_assume!! (
2237 ctx:: FastLDFContext , right:: Distribution , vn:: VarName , vi:: OnlyAccsVarInfo
2338)
2439 # Don't need to read the data from the varinfo at all since it's
2540 # all inside the context.
26- range_and_linked = ctx. varname_ranges[vn]
41+ range_and_linked = get_range_and_linked ( ctx, vn)
2742 y = @view ctx. params[range_and_linked. range]
28- is_linked = range_and_linked. is_linked
29- f = if is_linked
43+ f = if range_and_linked. is_linked
3044 from_linked_vec_transform (right)
3145 else
3246 from_vec_transform (right)
5165struct FastLDF{
5266 M<: Model ,
5367 F<: Function ,
68+ N<: NamedTuple ,
5469 AD<: Union{ADTypes.AbstractADType,Nothing} ,
5570 ADP<: Union{Nothing,DI.GradientPrep} ,
5671}
5772 _model:: M
5873 _getlogdensity:: F
74+ # See FastLDFContext for explanation of these two fields
75+ _iden_varname_ranges:: N
5976 _varname_ranges:: Dict{VarName,RangeAndLinked}
6077 _adtype:: AD
6178 _adprep:: ADP
@@ -70,6 +87,7 @@ struct FastLDF{
7087 ) where {syms}
7188 # Figure out which variable corresponds to which index, and
7289 # which variables are linked.
90+ all_iden_ranges = NamedTuple ()
7391 all_ranges = Dict {VarName,RangeAndLinked} ()
7492 offset = 1
7593 for sym in syms
@@ -78,7 +96,16 @@ struct FastLDF{
7896 len = length (md. ranges[idx])
7997 is_linked = md. is_transformed[idx]
8098 range = offset: (offset + len - 1 )
81- all_ranges[vn] = RangeAndLinked (range, is_linked)
99+ if AbstractPPL. getoptic (vn) === identity
100+ all_iden_ranges = merge (
101+ all_iden_ranges,
102+ NamedTuple (
103+ AbstractPPL. getsym (vn) => RangeAndLinked (range, is_linked)
104+ ),
105+ )
106+ else
107+ all_ranges[vn] = RangeAndLinked (range, is_linked)
108+ end
82109 offset += len
83110 end
84111 end
@@ -90,23 +117,32 @@ struct FastLDF{
90117 adtype = tweak_adtype (adtype, model, varinfo)
91118 x = [val for val in varinfo[:]]
92119 DI. prepare_gradient (
93- FastLogDensityAt (model, getlogdensity, all_ranges), adtype, x
120+ FastLogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges),
121+ adtype,
122+ x,
94123 )
95124 end
96125
97- return new {typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)} (
98- model, getlogdensity, all_ranges, adtype, prep
126+ return new{
127+ typeof (model),
128+ typeof (getlogdensity),
129+ typeof (all_iden_ranges),
130+ typeof (adtype),
131+ typeof (prep),
132+ }(
133+ model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep
99134 )
100135 end
101136end
102137
"""
    FastLogDensityAt{Mod<:Model,Fn<:Function,NT<:NamedTuple}

Bundle of a model, a log-density getter and the variable-range lookup tables.

# Fields
- `_model`: the model to evaluate.
- `_getlogdensity`: function stored alongside the model.
- `_iden_varname_ranges`: `NamedTuple` of ranges for `VarName`s with an
  `identity` optic.
- `_varname_ranges`: `Dict` of ranges for all other `VarName`s.
"""
struct FastLogDensityAt{Mod<:Model,Fn<:Function,NT<:NamedTuple}
    _model::Mod
    _getlogdensity::Fn
    # Ranges of identity-optic VarNames, keyed by symbol.
    _iden_varname_ranges::NT
    # Ranges of every other VarName.
    _varname_ranges::Dict{VarName,RangeAndLinked}
end
108144function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
109- ctx = FastLDFContext (f. _varname_ranges, params)
145+ ctx = FastLDFContext (f. _iden_varname_ranges, f . _varname_ranges, params)
110146 model = DynamicPPL. setleafcontext (f. _model, ctx)
111147 # This can obviously also be optimised for the case where not
112148 # all accumulators are needed.
@@ -118,14 +154,23 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
118154end
119155
"""
    LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})

Evaluate `fldf` at the parameter vector `params`.

A `FastLogDensityAt` is assembled from the model, log-density getter and range
tables stored in `fldf`, and is then called on `params`; its result is returned
unchanged.
"""
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
    evaluator = FastLogDensityAt(
        fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
    )
    return evaluator(params)
end
123163
124164function LogDensityProblems. logdensity_and_gradient (
125165 fldf:: FastLDF , params:: AbstractVector{<:Real}
126166)
127167 return DI. value_and_gradient (
128- FastLogDensityAt (fldf. _model, fldf. _getlogdensity, fldf. _varname_ranges),
168+ FastLogDensityAt (
169+ fldf. _model,
170+ fldf. _getlogdensity,
171+ fldf. _iden_varname_ranges,
172+ fldf. _varname_ranges,
173+ ),
129174 fldf. _adprep,
130175 fldf. _adtype,
131176 params,
0 commit comments