@@ -20,8 +20,9 @@ using DynamicPPL:
2020 hasconditioned_nested,
2121 getconditioned_nested,
2222 collapse_prefix_stack,
23- prefix_cond_and_fixed_variables,
24- getvalue
23+ prefix_cond_and_fixed_variables
24+ using LinearAlgebra: I
25+ using Random: Xoshiro
2526
2627using EnzymeCore
2728
@@ -103,7 +104,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
103104 # sometimes only the main symbol (e.g. it contains `x` when
104105 # `vn` is `x[1]`)
105106 for vn in conditioned_vns
106- val = DynamicPPL . getvalue (conditioned_values, vn)
107+ val = getvalue (conditioned_values, vn)
107108 # These VarNames are present in the conditioning values, so
108109 # we should always be able to extract the value.
109110 @test hasconditioned_nested (context, vn)
@@ -433,12 +434,180 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown()
433434 end
434435
435436 @testset " InitContext" begin
436- @testset " PriorInit" begin end
437+ empty_varinfos = [
438+ VarInfo (),
439+ DynamicPPL. typed_varinfo (VarInfo ()),
440+ VarInfo (DynamicPPL. VarNamedVector ()),
441+ DynamicPPL. typed_vector_varinfo (DynamicPPL. typed_varinfo (VarInfo ())),
442+ SimpleVarInfo (),
443+ SimpleVarInfo (Dict {VarName,Any} ()),
444+ ]
445+
446+ @model function test_init_model ()
447+ x ~ Normal ()
448+ y ~ MvNormal (fill (x, 2 ), I)
449+ 1.0 ~ Normal ()
450+ return nothing
451+ end
452+ function test_generating_new_values (strategy:: AbstractInitStrategy )
453+ @testset " generating new values: $(typeof (strategy)) " begin
454+ # Check that init!! can generate values that weren't there
455+ # previously.
456+ model = test_init_model ()
457+ for empty_vi in empty_varinfos
458+ this_vi = deepcopy (empty_vi)
459+ _, vi = DynamicPPL. init!! (model, this_vi, strategy)
460+ @test Set (keys (vi)) == Set ([@varname (x), @varname (y)])
461+ x, y = vi[@varname (x)], vi[@varname (y)]
462+ @test x isa Real
463+ @test y isa AbstractVector{<: Real }
464+ @test length (y) == 2
465+ (; logprior, loglikelihood) = getlogp (vi)
466+ @test logpdf (Normal (), x) + logpdf (MvNormal (fill (x, 2 ), I), y) ==
467+ logprior
468+ @test logpdf (Normal (), 1.0 ) == loglikelihood
469+ end
470+ end
471+ end
472+ function test_replacing_values (strategy:: AbstractInitStrategy )
473+ @testset " replacing old values: $(typeof (strategy)) " begin
474+ # Check that init!! can overwrite values that were already there.
475+ model = test_init_model ()
476+ for empty_vi in empty_varinfos
477+ # start by generating some rubbish values
478+ vi = deepcopy (empty_vi)
479+ old_x, old_y = 100000.00 , [300000.00 , 500000.00 ]
480+ push!! (vi, @varname (x), old_x, Normal ())
481+ push!! (vi, @varname (y), old_y, MvNormal (fill (old_x, 2 ), I))
482+ # then overwrite it
483+ _, new_vi = DynamicPPL. init!! (model, vi, strategy)
484+ new_x, new_y = new_vi[@varname (x)], new_vi[@varname (y)]
485+ # check that the values are (presumably) different
486+ @test old_x != new_x
487+ @test old_y != new_y
488+ end
489+ end
490+ end
491+ function test_rng_respected (strategy:: AbstractInitStrategy )
492+ @testset " check that RNG is respected: $(typeof (strategy)) " begin
493+ model = test_init_model ()
494+ for empty_vi in empty_varinfos
495+ _, vi1 = DynamicPPL. init!! (
496+ Xoshiro (468 ), model, deepcopy (empty_vi), strategy
497+ )
498+ _, vi2 = DynamicPPL. init!! (
499+ Xoshiro (468 ), model, deepcopy (empty_vi), strategy
500+ )
501+ _, vi3 = DynamicPPL. init!! (
502+ Xoshiro (469 ), model, deepcopy (empty_vi), strategy
503+ )
504+ @test vi1[@varname (x)] == vi2[@varname (x)]
505+ @test vi1[@varname (y)] == vi2[@varname (y)]
506+ @test vi1[@varname (x)] != vi3[@varname (x)]
507+ @test vi1[@varname (y)] != vi3[@varname (y)]
508+ end
509+ end
510+ end
437511
438- @testset " UniformInit" begin end
512+ @testset " PriorInit" begin
513+ test_generating_new_values (PriorInit ())
514+ test_replacing_values (PriorInit ())
515+ test_rng_respected (PriorInit ())
516+
517+ @testset " check that values are within support" begin
518+ # Not many other sensible checks we can do for priors.
519+ @model just_unif () = x ~ Uniform (0.0 , 1e-7 )
520+ for _ in 1 : 100
521+ _, vi = DynamicPPL. init!! (just_unif (), VarInfo (), PriorInit ())
522+ @test vi[@varname (x)] isa Real
523+ @test 0.0 <= vi[@varname (x)] <= 1e-7
524+ end
525+ end
526+ end
439527
440- @testset " ParamsInit" begin end
528+ @testset " UniformInit" begin
529+ test_generating_new_values (UniformInit ())
530+ test_replacing_values (UniformInit ())
531+ test_rng_respected (UniformInit ())
532+
533+ @testset " check that bounds are respected" begin
534+ @testset " unconstrained" begin
535+ umin, umax = - 1.0 , 1.0
536+ @model just_norm () = x ~ Normal ()
537+ for _ in 1 : 100
538+ _, vi = DynamicPPL. init!! (
539+ just_norm (), VarInfo (), UniformInit (umin, umax)
540+ )
541+ @test vi[@varname (x)] isa Real
542+ @test umin <= vi[@varname (x)] <= umax
543+ end
544+ end
545+ @testset " constrained" begin
546+ umin, umax = - 1.0 , 1.0
547+ @model just_beta () = x ~ Beta (2 , 2 )
548+ inv_bijector = inverse (Bijectors. bijector (Beta (2 , 2 )))
549+ tmin, tmax = inv_bijector (umin), inv_bijector (umax)
550+ for _ in 1 : 100
551+ _, vi = DynamicPPL. init!! (
552+ just_beta (), VarInfo (), UniformInit (umin, umax)
553+ )
554+ @test vi[@varname (x)] isa Real
555+ @test tmin <= vi[@varname (x)] <= tmax
556+ end
557+ end
558+ end
559+ end
441560
442- @testset " rng is respected (at least with PriorInit" begin end
561+ @testset " ParamsInit" begin
562+ @testset " given full set of parameters" begin
563+ # test_init_model has x ~ Normal() and y ~ MvNormal(zeros(2), I)
564+ my_x, my_y = 1.0 , [2.0 , 3.0 ]
565+ params_nt = (; x= my_x, y= my_y)
566+ params_dict = Dict (@varname (x) => my_x, @varname (y) => my_y)
567+ model = test_init_model ()
568+ for empty_vi in empty_varinfos
569+ _, vi = DynamicPPL. init!! (
570+ model, deepcopy (empty_vi), ParamsInit (params_nt)
571+ )
572+ @test vi[@varname (x)] == my_x
573+ @test vi[@varname (y)] == my_y
574+ logp_nt = getlogp (vi)
575+ _, vi = DynamicPPL. init!! (
576+ model, deepcopy (empty_vi), ParamsInit (params_dict)
577+ )
578+ @test vi[@varname (x)] == my_x
579+ @test vi[@varname (y)] == my_y
580+ logp_dict = getlogp (vi)
581+ @test logp_nt == logp_dict
582+ end
583+ end
584+
585+ @testset " given only partial parameters" begin
586+ # In this case, we expect `ParamsInit` to use the value of x, and
587+ # generate a new value for y.
588+ my_x = 1.0
589+ params_nt = (; x= my_x)
590+ params_dict = Dict (@varname (x) => my_x)
591+ model = test_init_model ()
592+ for empty_vi in empty_varinfos
593+ _, vi = DynamicPPL. init!! (
594+ Xoshiro (468 ), model, deepcopy (empty_vi), ParamsInit (params_nt)
595+ )
596+ @test vi[@varname (x)] == my_x
597+ nt_y = vi[@varname (y)]
598+ @test nt_y isa AbstractVector{<: Real }
599+ @test length (nt_y) == 2
600+ _, vi = DynamicPPL. init!! (
601+ Xoshiro (469 ), model, deepcopy (empty_vi), ParamsInit (params_dict)
602+ )
603+ @test vi[@varname (x)] == my_x
604+ dict_y = vi[@varname (y)]
605+ @test dict_y isa AbstractVector{<: Real }
606+ @test length (dict_y) == 2
607+ # the values should be different since we used different seeds
608+ @test dict_y != nt_y
609+ end
610+ end
611+ end
443612 end
444613end
0 commit comments