@@ -12,21 +12,35 @@ struct RangeAndLinked
1212 is_linked:: Bool
1313end
1414
"""
    FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext

Context that carries a flat parameter vector together with the lookup tables
recording which slice of that vector belongs to each variable, and whether that
slice is stored in linked form.
"""
struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext
    # Entries for VarNames whose optic is `identity`, keyed by the variable's symbol.
    # They are kept in a NamedTuple instead of the Dict below purely for speed: for
    # plain evaluation the gain is modest (maybe 1.5x), but when doing AD with
    # Mooncake it is HUGE (around 4x). The exact numbers depend on the model.
    iden_varname_ranges::N
    # Entries for every other VarName.
    varname_ranges::Dict{VarName,RangeAndLinked}
    # The full parameter vector; variable values are obtained by indexing into it.
    params::T
end
# FastLDFContext holds no child context among its fields, so it is declared a leaf.
function DynamicPPL.NodeTrait(::FastLDFContext)
    return IsLeaf()
end
"""
    get_range_and_linked(ctx::FastLDFContext, vn::VarName)

Return the entry stored in `ctx` for the variable `vn`.

A `VarName` whose optic is `identity` is looked up by its symbol in
`ctx.iden_varname_ranges`; any other `VarName` is looked up in the
`ctx.varname_ranges` dictionary.
"""
get_range_and_linked(ctx::FastLDFContext, vn::VarName) = ctx.varname_ranges[vn]
function get_range_and_linked(
    ctx::FastLDFContext, ::VarName{sym,typeof(identity)}
) where {sym}
    # `sym` comes from the VarName's type, so only the type is needed here,
    # not the VarName value itself.
    iden_ranges = ctx.iden_varname_ranges
    return iden_ranges[sym]
end
2035
2136function tilde_assume!! (
2237 ctx:: FastLDFContext , right:: Distribution , vn:: VarName , vi:: OnlyAccsVarInfo
2338)
2439 # Don't need to read the data from the varinfo at all since it's
2540 # all inside the context.
26- range_and_linked = ctx. varname_ranges[vn]
41+ range_and_linked = get_range_and_linked ( ctx, vn)
2742 y = @view ctx. params[range_and_linked. range]
28- is_linked = range_and_linked. is_linked
29- f = if is_linked
43+ f = if range_and_linked. is_linked
3044 from_linked_vec_transform (right)
3145 else
3246 from_vec_transform (right)
5165struct FastLDF{
5266 M<: Model ,
5367 F<: Function ,
68+ N<: NamedTuple ,
5469 AD<: Union{ADTypes.AbstractADType,Nothing} ,
5570 ADP<: Union{Nothing,DI.GradientPrep} ,
5671}
5772 _model:: M
5873 _getlogdensity:: F
74+ # See FastLDFContext for explanation of these two fields
75+ _iden_varname_ranges:: N
5976 _varname_ranges:: Dict{VarName,RangeAndLinked}
6077 _adtype:: AD
6178 _adprep:: ADP
@@ -70,6 +87,7 @@ struct FastLDF{
7087 ) where {syms}
7188 # Figure out which variable corresponds to which index, and
7289 # which variables are linked.
90+ all_iden_ranges = NamedTuple ()
7391 all_ranges = Dict {VarName,RangeAndLinked} ()
7492 offset = 1
7593 for sym in syms
@@ -78,7 +96,14 @@ struct FastLDF{
7896 len = length (md. ranges[idx])
7997 is_linked = md. is_transformed[idx]
8098 range = offset: (offset + len - 1 )
81- all_ranges[vn] = RangeAndLinked (range, is_linked)
99+ if AbstractPPL. getoptic (vn) === identity
100+ all_iden_ranges = merge (
101+ all_iden_ranges,
102+ NamedTuple {(Symbol(vn),)} ((RangeAndLinked (range, is_linked),)),
103+ )
104+ else
105+ all_ranges[vn] = RangeAndLinked (range, is_linked)
106+ end
82107 offset += len
83108 end
84109 end
@@ -90,23 +115,32 @@ struct FastLDF{
90115 adtype = tweak_adtype (adtype, model, varinfo)
91116 x = [val for val in varinfo[:]]
92117 DI. prepare_gradient (
93- FastLogDensityAt (model, getlogdensity, all_ranges), adtype, x
118+ FastLogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges),
119+ adtype,
120+ x,
94121 )
95122 end
96123
97- return new {typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)} (
98- model, getlogdensity, all_ranges, adtype, prep
124+ return new{
125+ typeof (model),
126+ typeof (getlogdensity),
127+ typeof (all_iden_ranges),
128+ typeof (adtype),
129+ typeof (prep),
130+ }(
131+ model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep
99132 )
100133 end
101134end
102135
"""
    FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}

Bundle of a model, a log-density getter and the variable lookup tables. Objects of
this type are callable on a parameter vector; the call builds a `FastLDFContext`
from the two lookup tables and the supplied parameters.
"""
struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
    # The model to be evaluated.
    _model::M
    # NOTE(review): presumably extracts the log density from the evaluation
    # result; its call site is not visible in this view, so confirm.
    _getlogdensity::F
    # Lookup table passed as the first argument of `FastLDFContext`.
    _iden_varname_ranges::N
    # Lookup table passed as the second argument of `FastLDFContext`.
    _varname_ranges::Dict{VarName,RangeAndLinked}
end
108142function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
109- ctx = FastLDFContext (f. _varname_ranges, params)
143+ ctx = FastLDFContext (f. _iden_varname_ranges, f . _varname_ranges, params)
110144 model = DynamicPPL. setleafcontext (f. _model, ctx)
111145 # This can obviously also be optimised for the case where not
112146 # all accumulators are needed.
@@ -118,14 +152,23 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
118152end
119153
# Evaluate `fldf` at `params` by assembling a `FastLogDensityAt` from the fields
# stored on `fldf` and calling it on the parameter vector.
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
    evaluator = FastLogDensityAt(
        fldf._model,
        fldf._getlogdensity,
        fldf._iden_varname_ranges,
        fldf._varname_ranges,
    )
    return evaluator(params)
end
123161
124162function LogDensityProblems. logdensity_and_gradient (
125163 fldf:: FastLDF , params:: AbstractVector{<:Real}
126164)
127165 return DI. value_and_gradient (
128- FastLogDensityAt (fldf. _model, fldf. _getlogdensity, fldf. _varname_ranges),
166+ FastLogDensityAt (
167+ fldf. _model,
168+ fldf. _getlogdensity,
169+ fldf. _iden_varname_ranges,
170+ fldf. _varname_ranges,
171+ ),
129172 fldf. _adprep,
130173 fldf. _adtype,
131174 params,
0 commit comments