11using DynamicPPL: LogDensityFunction
22
33@testset " Automatic differentiation" begin
4+ # Used as the ground truth that others are compared against.
5+ ref_adtype = AutoForwardDiff ()
6+ test_adtypes = [
7+ AutoReverseDiff (; compile= false ),
8+ AutoReverseDiff (; compile= true ),
9+ AutoMooncake (; config= nothing ),
10+ ]
11+
412 @testset " Unsupported backends" begin
513 @model demo () = x ~ Normal ()
614 @test_logs (:warn , r" not officially supported" ) LogDensityFunction (
@@ -18,15 +26,10 @@ using DynamicPPL: LogDensityFunction
1826 f = LogDensityFunction (m, varinfo)
1927 x = DynamicPPL. getparams (f)
2028 # Calculate reference logp + gradient of logp using ForwardDiff
21- ref_adtype = ADTypes. AutoForwardDiff ()
2229 ref_ldf = LogDensityFunction (m, varinfo; adtype= ref_adtype)
2330 ref_logp, ref_grad = LogDensityProblems. logdensity_and_gradient (ref_ldf, x)
2431
25- @testset " $adtype " for adtype in [
26- AutoReverseDiff (; compile= false ),
27- AutoReverseDiff (; compile= true ),
28- AutoMooncake (; config= nothing ),
29- ]
32+ @testset " $adtype " for adtype in test_adtypes
3033 @info " Testing AD on: $(m. f) - $(short_varinfo_name (varinfo)) - $adtype "
3134
3235 # Put predicates here to avoid long lines
@@ -103,4 +106,66 @@ using DynamicPPL: LogDensityFunction
103106 )
104107 @test LogDensityProblems. logdensity_and_gradient (ldf, vi[:]) isa Any
105108 end
109+
110+ # Test that various different ways of specifying array types as arguments work with all
111+ # ADTypes.
112+ @testset " Array argument types" begin
113+ test_m = randn (2 , 3 )
114+
115+ function eval_logp_and_grad (model, m, adtype)
116+ ldf = LogDensityFunction (model (); adtype= adtype)
117+ return LogDensityProblems. logdensity_and_gradient (ldf, m[:])
118+ end
119+
120+ @model function scalar_matrix_model (:: Type{T} = Float64) where {T<: Real }
121+ m = Matrix {T} (undef, 2 , 3 )
122+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
123+ end
124+
125+ scalar_matrix_model_reference = eval_logp_and_grad (
126+ scalar_matrix_model, test_m, ref_adtype
127+ )
128+
129+ @model function matrix_model (:: Type{T} = Matrix{Float64}) where {T}
130+ m = T (undef, 2 , 3 )
131+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
132+ end
133+
134+ matrix_model_reference = eval_logp_and_grad (matrix_model, test_m, ref_adtype)
135+
136+ @model function scalar_array_model (:: Type{T} = Float64) where {T<: Real }
137+ m = Array {T} (undef, 2 , 3 )
138+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
139+ end
140+
141+ scalar_array_model_reference = eval_logp_and_grad (
142+ scalar_array_model, test_m, ref_adtype
143+ )
144+
145+ @model function array_model (:: Type{T} = Array{Float64}) where {T}
146+ m = T (undef, 2 , 3 )
147+ return m ~ filldist (MvNormal (zeros (2 ), I), 3 )
148+ end
149+
150+ array_model_reference = eval_logp_and_grad (array_model, test_m, ref_adtype)
151+
152+ @testset " $adtype " for adtype in test_adtypes
153+ scalar_matrix_model_logp_and_grad = eval_logp_and_grad (
154+ scalar_matrix_model, test_m, adtype
155+ )
156+ @test scalar_matrix_model_logp_and_grad[1 ] ≈ scalar_matrix_model_reference[1 ]
157+ @test scalar_matrix_model_logp_and_grad[2 ] ≈ scalar_matrix_model_reference[2 ]
158+ matrix_model_logp_and_grad = eval_logp_and_grad (matrix_model, test_m, adtype)
159+ @test matrix_model_logp_and_grad[1 ] ≈ matrix_model_reference[1 ]
160+ @test matrix_model_logp_and_grad[2 ] ≈ matrix_model_reference[2 ]
161+ scalar_array_model_logp_and_grad = eval_logp_and_grad (
162+ scalar_array_model, test_m, adtype
163+ )
164+ @test scalar_array_model_logp_and_grad[1 ] ≈ scalar_array_model_reference[1 ]
165+ @test scalar_array_model_logp_and_grad[2 ] ≈ scalar_array_model_reference[2 ]
166+ array_model_logp_and_grad = eval_logp_and_grad (array_model, test_m, adtype)
167+ @test array_model_logp_and_grad[1 ] ≈ array_model_reference[1 ]
168+ @test array_model_logp_and_grad[2 ] ≈ array_model_reference[2 ]
169+ end
170+ end
106171end
0 commit comments