Skip to content

Commit 94ec51f

Browse files
committed
performance optimisations
1 parent 5ed4295 commit 94ec51f

File tree

1 file changed

+55
-12
lines changed

1 file changed

+55
-12
lines changed

src/fastldf.jl

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,35 @@ struct RangeAndLinked
1212
is_linked::Bool
1313
end
1414

15-
struct FastLDFContext{T<:AbstractVector{<:Real}} <: AbstractContext
15+
struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext
16+
# The ranges of identity VarNames are stored in a NamedTuple for performance
17+
# reasons. For just plain evaluation this doesn't make _that_ much of a
18+
# difference (maybe 1.5x), but when doing AD with Mooncake this makes a HUGE
19+
# difference (around 4x). Of course, the exact numbers depend on the model.
20+
iden_varname_ranges::N
21+
# This Dict stores the ranges for all other VarNames
1622
varname_ranges::Dict{VarName,RangeAndLinked}
23+
# The full parameter vector which we index into to get variable values
1724
params::T
1825
end
1926
DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf()
27+
function get_range_and_linked(
28+
ctx::FastLDFContext, ::VarName{sym,typeof(identity)}
29+
) where {sym}
30+
return ctx.iden_varname_ranges[sym]
31+
end
32+
function get_range_and_linked(ctx::FastLDFContext, vn::VarName)
33+
return ctx.varname_ranges[vn]
34+
end
2035

2136
function tilde_assume!!(
2237
ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo
2338
)
2439
# Don't need to read the data from the varinfo at all since it's
2540
# all inside the context.
26-
range_and_linked = ctx.varname_ranges[vn]
41+
range_and_linked = get_range_and_linked(ctx, vn)
2742
y = @view ctx.params[range_and_linked.range]
28-
is_linked = range_and_linked.is_linked
29-
f = if is_linked
43+
f = if range_and_linked.is_linked
3044
from_linked_vec_transform(right)
3145
else
3246
from_vec_transform(right)
@@ -51,11 +65,14 @@ end
5165
struct FastLDF{
5266
M<:Model,
5367
F<:Function,
68+
N<:NamedTuple,
5469
AD<:Union{ADTypes.AbstractADType,Nothing},
5570
ADP<:Union{Nothing,DI.GradientPrep},
5671
}
5772
_model::M
5873
_getlogdensity::F
74+
# See FastLDFContext for explanation of these two fields
75+
_iden_varname_ranges::N
5976
_varname_ranges::Dict{VarName,RangeAndLinked}
6077
_adtype::AD
6178
_adprep::ADP
@@ -70,6 +87,7 @@ struct FastLDF{
7087
) where {syms}
7188
# Figure out which variable corresponds to which index, and
7289
# which variables are linked.
90+
all_iden_ranges = NamedTuple()
7391
all_ranges = Dict{VarName,RangeAndLinked}()
7492
offset = 1
7593
for sym in syms
@@ -78,7 +96,14 @@ struct FastLDF{
7896
len = length(md.ranges[idx])
7997
is_linked = md.is_transformed[idx]
8098
range = offset:(offset + len - 1)
81-
all_ranges[vn] = RangeAndLinked(range, is_linked)
99+
if AbstractPPL.getoptic(vn) === identity
100+
all_iden_ranges = merge(
101+
all_iden_ranges,
102+
NamedTuple{(Symbol(vn),)}((RangeAndLinked(range, is_linked),)),
103+
)
104+
else
105+
all_ranges[vn] = RangeAndLinked(range, is_linked)
106+
end
82107
offset += len
83108
end
84109
end
@@ -90,23 +115,32 @@ struct FastLDF{
90115
adtype = tweak_adtype(adtype, model, varinfo)
91116
x = [val for val in varinfo[:]]
92117
DI.prepare_gradient(
93-
FastLogDensityAt(model, getlogdensity, all_ranges), adtype, x
118+
FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
119+
adtype,
120+
x,
94121
)
95122
end
96123

97-
return new{typeof(model),typeof(getlogdensity),typeof(adtype),typeof(prep)}(
98-
model, getlogdensity, all_ranges, adtype, prep
124+
return new{
125+
typeof(model),
126+
typeof(getlogdensity),
127+
typeof(all_iden_ranges),
128+
typeof(adtype),
129+
typeof(prep),
130+
}(
131+
model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep
99132
)
100133
end
101134
end
102135

103-
struct FastLogDensityAt{M<:Model,F<:Function}
136+
struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
104137
_model::M
105138
_getlogdensity::F
139+
_iden_varname_ranges::N
106140
_varname_ranges::Dict{VarName,RangeAndLinked}
107141
end
108142
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
109-
ctx = FastLDFContext(f._varname_ranges, params)
143+
ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params)
110144
model = DynamicPPL.setleafcontext(f._model, ctx)
111145
# This can obviously also be optimised for the case where not
112146
# all accumulators are needed.
@@ -118,14 +152,23 @@ function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
118152
end
119153

120154
function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real})
121-
return FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges)(params)
155+
return FastLogDensityAt(
156+
fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges
157+
)(
158+
params
159+
)
122160
end
123161

124162
function LogDensityProblems.logdensity_and_gradient(
125163
fldf::FastLDF, params::AbstractVector{<:Real}
126164
)
127165
return DI.value_and_gradient(
128-
FastLogDensityAt(fldf._model, fldf._getlogdensity, fldf._varname_ranges),
166+
FastLogDensityAt(
167+
fldf._model,
168+
fldf._getlogdensity,
169+
fldf._iden_varname_ranges,
170+
fldf._varname_ranges,
171+
),
129172
fldf._adprep,
130173
fldf._adtype,
131174
params,

0 commit comments

Comments
 (0)