@@ -4,7 +4,14 @@ using ADTypes: AbstractADType, AutoForwardDiff
44using Chairmarks: @be
55import DifferentiationInterface as DI
66using DocStringExtensions
7- using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
7+ using DynamicPPL:
8+ Model,
9+ LogDensityFunction,
10+ VarInfo,
11+ AbstractVarInfo,
12+ link,
13+ DefaultContext,
14+ AbstractContext
815using LogDensityProblems: logdensity, logdensity_and_gradient
916using Random: Random, Xoshiro
1017using Statistics: median
@@ -53,6 +60,8 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
5360 model:: Model
5461 " The VarInfo that was used"
5562 varinfo:: AbstractVarInfo
63+ " The evaluation context that was used"
64+ context:: AbstractContext
5665 " The values at which the model was evaluated"
5766 params:: Vector{Tparams}
5867 " The AD backend that was tested"
8392 grad_atol=1e-6,
8493 varinfo::AbstractVarInfo=link(VarInfo(model), model),
8594 params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
95+ context::AbstractContext=DefaultContext(),
8696 reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
8797 expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
8898 verbose=true,
@@ -136,7 +146,13 @@ Everything else is optional, and can be categorised into several groups:
136146 prep_params)`. You could then evaluate the gradient at a different set of
137147 parameters using the `params` keyword argument.
138148
139- 3. _How to specify the results to compare against._ (Only if `test=true`.)
149+ 3. _How to specify the evaluation context._
150+
151+ A `DynamicPPL.AbstractContext` can be passed as the `context` keyword
152+ argument to control the evaluation context. This defaults to
153+ `DefaultContext()`.
154+
155+ 4. _How to specify the results to compare against._ (Only if `test=true`.)
140156
141157 Once logp and its gradient has been calculated with the specified `adtype`,
142158 it must be tested for correctness.
@@ -151,12 +167,12 @@ Everything else is optional, and can be categorised into several groups:
151167 The default reference backend is ForwardDiff. If none of these parameters are
152168 specified, ForwardDiff will be used to calculate the ground truth.
153169
154- 4 . _How to specify the tolerances._ (Only if `test=true`.)
170+ 5 . _How to specify the tolerances._ (Only if `test=true`.)
155171
156172 The tolerances for the value and gradient can be set using `value_atol` and
157173 `grad_atol`. These default to 1e-6.
158174
159- 5 . _Whether to output extra logging information._
175+ 6 . _Whether to output extra logging information._
160176
161177 By default, this function prints messages when it runs. To silence it, set
162178 `verbose=false`.
@@ -179,6 +195,7 @@ function run_ad(
179195 grad_atol:: AbstractFloat = 1e-6 ,
180196 varinfo:: AbstractVarInfo = link (VarInfo (model), model),
181197 params:: Union{Nothing,Vector{<:AbstractFloat}} = nothing ,
198+ context:: AbstractContext = DefaultContext (),
182199 reference_adtype:: AbstractADType = REFERENCE_ADTYPE,
183200 expected_value_and_grad:: Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}} = nothing ,
184201 verbose= true ,
@@ -190,7 +207,7 @@ function run_ad(
190207
191208 verbose && @info " Running AD on $(model. f) with $(adtype) \n "
192209 verbose && println (" params : $(params) " )
193- ldf = LogDensityFunction (model, varinfo; adtype= adtype)
210+ ldf = LogDensityFunction (model, varinfo, context ; adtype= adtype)
194211
195212 value, grad = logdensity_and_gradient (ldf, params)
196213 grad = collect (grad)
@@ -199,7 +216,7 @@ function run_ad(
199216 if test
200217 # Calculate ground truth to compare against
201218 value_true, grad_true = if expected_value_and_grad === nothing
202- ldf_reference = LogDensityFunction (model, varinfo; adtype= reference_adtype)
219+ ldf_reference = LogDensityFunction (model, varinfo, context ; adtype= reference_adtype)
203220 logdensity_and_gradient (ldf_reference, params)
204221 else
205222 expected_value_and_grad
@@ -228,6 +245,7 @@ function run_ad(
228245 return ADResult (
229246 model,
230247 varinfo,
248+ context,
231249 params,
232250 adtype,
233251 value_atol,
0 commit comments