|
237 | 237 | end |
238 | 238 | end |
239 | 239 |
|
240 | | -@testset verbose = true "AD / SamplingContext" begin |
241 | | - # AD tests for gradient-based samplers need to be run with SamplingContext |
242 | | - # because samplers can potentially use this to define custom behaviour in |
243 | | - # the tilde-pipeline and thus change the code executed during model |
244 | | - # evaluation. |
245 | | - @testset "adtype=$adtype" for adtype in ADTYPES |
246 | | - @testset "alg=$alg" for alg in [ |
247 | | - HMC(0.1, 10; adtype=adtype), |
248 | | - HMCDA(0.8, 0.75; adtype=adtype), |
249 | | - NUTS(1000, 0.8; adtype=adtype), |
250 | | - SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), |
251 | | - SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), |
252 | | - ] |
253 | | - @info "Testing AD for $alg" |
254 | | - |
255 | | - @testset "model=$(model.f)" for model in DEMO_MODELS |
256 | | - rng = StableRNG(123) |
257 | | - spl_model = DynamicPPL.contextualize( |
258 | | - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) |
259 | | - ) |
260 | | - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any |
261 | | - end |
262 | | - end |
263 | | - end |
264 | | -end |
265 | | - |
266 | 240 | @testset verbose = true "AD / GibbsContext" begin |
267 | | - # Gibbs sampling also needs extra AD testing because the models are |
| 241 | + # Gibbs sampling needs some extra AD testing because the models are |
268 | 242 | # executed with GibbsContext and a subsetted varinfo. (see e.g. |
269 | 243 | # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in |
270 | 244 | # src/mcmc/gibbs.jl -- the code here mimics what happens in those |
|
283 | 257 | model, varnames, deepcopy(global_vi) |
284 | 258 | ) |
285 | 259 | rng = StableRNG(123) |
286 | | - spl_model = DynamicPPL.contextualize( |
287 | | - model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) |
288 | | - ) |
289 | | - @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any |
| 260 | + @test run_ad(model, adtype; test=true, benchmark=false) isa Any |
290 | 261 | end |
291 | 262 | end |
292 | 263 | end |
|
0 commit comments