Skip to content

Commit f8b6fbd

Browse files
committed
Optimise performance for identity VarNames
1 parent 5ed4295 commit f8b6fbd

File tree

1 file changed

+57
-12
lines changed

1 file changed

+57
-12
lines changed

src/fastldf.jl

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,35 @@ struct RangeAndLinked
1212
is_linked::Bool
1313
end
1414

15-
struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext
15+
struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext
16+
# The ranges of identity VarNames are stored in a NamedTuple for performance
17+
# reasons. For just plain evaluation this doesn't make _that_ much of a
18+
# difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE
19+
# difference (around 4x). Of course, the exact numbers depend on the model.
20+
iden_varname_ranges::N
21+
# This Dict stores the ranges for all other VarNames
1622
varname_ranges::Dict{VarName,RangeAndLinked}
23+
# The full parameter vector which we index into to get variable values
1724
params::T
1825
end
1926
DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf()
27+
function get_range_and_linked(
28+
ctx::FastLDFContext, ::VarName{sym,typeof(identity)}
29+
) where {sym}
30+
return ctx.iden_varname_ranges[sym]
31+
end
32+
function get_range_and_linked(ctx::FastLDFContext, vn::VarName)
33+
return ctx.varname_ranges[vn]
34+
end
2035

2136
function tilde_assume!!(
2237
ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo
2338
)
2439
# Don't need to read the data from the varinfo at all since it's
2540
# all inside the context.
26-
range_and_linked = ctx.varname_ranges[vn]
41+
range_and_linked = get_range_and_linked(ctx, vn)
2742
y = @view ctx.params[range_and_linked.range]
28-
is_linked = range_and_linked.is_linked
29-
f = if is_linked
43+
f = if range_and_linked.is_linked
3044
from_linked_vec_transform(right)
3145
else
3246
from_vec_transform(right)
@@ -51,11 +65,14 @@ end
5165
struct FastLDF{
5266
M<:Model,
5367
F<:Function,
68+
N<:NamedTuple,
5469
AD<:Union{ADTypes.AbstractADType,Nothing},
5570
ADP<:Union{Nothing,DI.GradientPrep},
5671
}
5772
_model::M
5873
_getlogdensity::F
74+
# See FastLDFContext for explanation of these two fields
75+
_iden_varname_ranges::N
5976
_varname_ranges::Dict{VarName,RangeAndLinked}
6077
_adtype::AD
6178
_adprep::ADP
@@ -70,6 +87,7 @@ struct FastLDF{
7087
) where {syms}
7188
# Figure out which variable corresponds to which index, and
7289
# which variables are linked.
90+
all_iden_ranges = NamedTuple()
7391
all_ranges = Dict{VarName,RangeAndLinked}()
7492
offset = 1
7593
for sym in syms
@@ -78,7 +96,16 @@ struct FastLDF{
7896
len = length(md.ranges[idx])
7997
is_linked = md.is_transformed[idx]
8098
range = offset:(offset + len - 1)
81-
all_ranges[vn] = RangeAndLinked(range, is_linked)
99+
if AbstractPPL.getoptic(vn) === identity
100+
all_iden_ranges = merge(
101+
all_iden_ranges,
102+
NamedTuple(
103+
AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked)
104+
),
105+
)
106+
else
107+
all_ranges[vn] = RangeAndLinked(range, is_linked)
108+
end
82109
offset += len
83110
end
84111
end
@@ -90,23 +117,32 @@ struct FastLDF{
90117
adtype = tweak_adtype(adtype, model, varinfo)
91118
x = [val for val in varinfo[:]]
92119
DI.prepare_gradient(
93-
FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x
120+
FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
121+
adtype,
122+
x,
94123
)
95124
end
96125

97-
return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}(
98-
model, getlogdensity, all_ranges, adtype, prep
126+
return new{
127+
typeof(model),
128+
typeof(getlogdensity),
129+
typeof(all_iden_ranges),
130+
typeof(adtype),
131+
typeof(prep),
132+
}(
133+
model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep
99134
)
100135
end
101136
end
102137

103-
struct FastLogDensityAt{M<:Model,F<:Function}
138+
struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
104139
_model::M
105140
_getlogdensity::F
141+
_iden_varname_ranges::N
106142
_varname_ranges::Dict{VarName,RangeAndLinked}
107143
end
108144
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
109-
ctx = FastLDFContext(f._varname_ranges, params)
145+
ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params)
110146
model = DynamicPPL.setleafcontext(f._model, ctx)
111147
# This can obviously also be optimised for the case where not
112148
# all accumulators are needed.
@@ -118,14 +154,23 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
118154
end
119155

120156
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
121-
return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params)
157+
return FastLogDensityAt(
158+
fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
159+
)(
160+
params
161+
)
122162
end
123163

124164
function LogDensityProblems.logdensity_and_gradient(
125165
fldf::FastLDF, params::AbstractVector{<:Real}
126166
)
127167
return DI.value_and_gradient(
128-
FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges),
168+
FastLogDensityAt(
169+
fldf._model,
170+
fldf._getlogdensity,
171+
fldf._iden_varname_ranges,
172+
fldf._varname_ranges,
173+
),
129174
fldf._adprep,
130175
fldf._adtype,
131176
params,

0 commit comments

Comments
 (0)