@@ -22,60 +22,59 @@ using Turing
2222 # Set a seed
2323 rng = StableRNG (123 )
2424 @testset " constrained bounded" begin
25- obs = [0 ,1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
25+ obs = [0 , 1 , 0 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ]
2626
2727 @model function constrained_test (obs)
28- p ~ Beta (2 ,2 )
29- for i = 1 : length (obs)
28+ p ~ Beta (2 , 2 )
29+ for i in 1 : length (obs)
3030 obs[i] ~ Bernoulli (p)
3131 end
32- p
32+ return p
3333 end
3434
3535 chain = sample (
3636 rng,
3737 constrained_test (obs),
3838 HMC (1.5 , 3 ; adtype= adbackend),# using a large step size (1.5)
39- 1000 )
39+ 1000 ,
40+ )
4041
41- check_numerical (chain, [:p ], [10 / 14 ], atol= 0.1 )
42+ check_numerical (chain, [:p ], [10 / 14 ]; atol= 0.1 )
4243 end
4344 @testset " constrained simplex" begin
44- obs12 = [1 ,2 , 1 , 2 , 2 , 2 , 2 , 2 , 2 , 2 ]
45+ obs12 = [1 , 2 , 1 , 2 , 2 , 2 , 2 , 2 , 2 , 2 ]
4546
4647 @model function constrained_simplex_test (obs12)
4748 ps ~ Dirichlet (2 , 3 )
4849 pd ~ Dirichlet (4 , 1 )
49- for i = 1 : length (obs12)
50+ for i in 1 : length (obs12)
5051 obs12[i] ~ Categorical (ps)
5152 end
5253 return ps
5354 end
5455
5556 chain = sample (
56- rng,
57- constrained_simplex_test (obs12),
58- HMC (0.75 , 2 ; adtype= adbackend),
59- 1000 )
57+ rng, constrained_simplex_test (obs12), HMC (0.75 , 2 ; adtype= adbackend), 1000
58+ )
6059
61- check_numerical (chain, [" ps[1]" , " ps[2]" ], [5 / 16 , 11 / 16 ], atol= 0.015 )
60+ check_numerical (chain, [" ps[1]" , " ps[2]" ], [5 / 16 , 11 / 16 ]; atol= 0.015 )
6261 end
6362 @testset " hmc reverse diff" begin
6463 alg = HMC (0.1 , 10 ; adtype= adbackend)
6564 res = sample (rng, gdemo_default, alg, 4000 )
66- check_gdemo (res, rtol= 0.1 )
65+ check_gdemo (res; rtol= 0.1 )
6766 end
6867 @testset " matrix support" begin
6968 @model function hmcmatrixsup ()
70- v ~ Wishart (7 , [1 0.5 ; 0.5 1 ])
69+ return v ~ Wishart (7 , [1 0.5 ; 0.5 1 ])
7170 end
7271
7372 model_f = hmcmatrixsup ()
7473 n_samples = 1_000
7574 vs = map (1 : 3 ) do _
7675 chain = sample (rng, model_f, HMC (0.15 , 7 ; adtype= adbackend), n_samples)
7776 r = reshape (Array (group (chain, :v )), n_samples, 2 , 2 )
78- reshape (mean (r; dims = 1 ), 2 , 2 )
77+ reshape (mean (r; dims= 1 ), 2 , 2 )
7978 end
8079
8180 @test maximum (abs, mean (vs) - (7 * [1 0.5 ; 0.5 1 ])) <= 0.5
@@ -92,10 +91,10 @@ using Turing
9291 M = N ÷ 4
9392 x1s = rand (M) * 5
9493 x2s = rand (M) * 5
95- xt1s = Array ([[x1s[i]; x2s[i]] for i = 1 : M])
96- append! (xt1s, Array ([[x1s[i] - 6 ; x2s[i] - 6 ] for i = 1 : M]))
97- xt0s = Array ([[x1s[i]; x2s[i] - 6 ] for i = 1 : M])
98- append! (xt0s, Array ([[x1s[i] - 6 ; x2s[i]] for i = 1 : M]))
94+ xt1s = Array ([[x1s[i]; x2s[i]] for i in 1 : M])
95+ append! (xt1s, Array ([[x1s[i] - 6 ; x2s[i] - 6 ] for i in 1 : M]))
96+ xt0s = Array ([[x1s[i]; x2s[i] - 6 ] for i in 1 : M])
97+ append! (xt0s, Array ([[x1s[i] - 6 ; x2s[i]] for i in 1 : M]))
9998
10099 xs = [xt1s; xt0s]
101100 ts = [ones (M); ones (M); zeros (M); zeros (M)]
@@ -106,20 +105,22 @@ using Turing
106105 var_prior = sqrt (1.0 / alpha) # variance of the Gaussian prior
107106
108107 @model function bnn (ts)
109- b1 ~ MvNormal ([0. ;0. ; 0. ],
110- [var_prior 0. 0. ; 0. var_prior 0. ; 0. 0. var_prior])
111- w11 ~ MvNormal ([0. ; 0. ], [var_prior 0. ; 0. var_prior])
112- w12 ~ MvNormal ([0. ; 0. ], [var_prior 0. ; 0. var_prior])
113- w13 ~ MvNormal ([0. ; 0. ], [var_prior 0. ; 0. var_prior])
108+ b1 ~ MvNormal (
109+ [0.0 ; 0.0 ; 0.0 ], [var_prior 0.0 0.0 ; 0.0 var_prior 0.0 ; 0.0 0.0 var_prior]
110+ )
111+ w11 ~ MvNormal ([0.0 ; 0.0 ], [var_prior 0.0 ; 0.0 var_prior])
112+ w12 ~ MvNormal ([0.0 ; 0.0 ], [var_prior 0.0 ; 0.0 var_prior])
113+ w13 ~ MvNormal ([0.0 ; 0.0 ], [var_prior 0.0 ; 0.0 var_prior])
114114 bo ~ Normal (0 , var_prior)
115115
116- wo ~ MvNormal ([0. ; 0 ; 0 ],
117- [var_prior 0. 0. ; 0. var_prior 0. ; 0. 0. var_prior])
118- for i = rand (1 : N, 10 )
116+ wo ~ MvNormal (
117+ [0.0 ; 0 ; 0 ], [var_prior 0.0 0.0 ; 0.0 var_prior 0.0 ; 0.0 0.0 var_prior]
118+ )
119+ for i in rand (1 : N, 10 )
119120 y = nn (xs[i], b1, w11, w12, w13, bo, wo)
120121 ts[i] ~ Bernoulli (y)
121122 end
122- b1, w11, w12, w13, bo, wo
123+ return b1, w11, w12, w13, bo, wo
123124 end
124125
125126 # Sampling
@@ -147,7 +148,7 @@ using Turing
147148 Random. seed! (12345 ) # particle samplers do not support user-provided `rng` yet
148149 alg3 = Gibbs (PG (20 , :s ), HMCDA (500 , 0.8 , 0.25 , :m ; init_ϵ= 0.05 , adtype= adbackend))
149150
150- res3 = sample (rng, gdemo_default, alg3, 3000 , discard_initial= 1000 )
151+ res3 = sample (rng, gdemo_default, alg3, 3000 ; discard_initial= 1000 )
151152 check_gdemo (res3)
152153 end
153154
@@ -191,8 +192,8 @@ using Turing
191192 @testset " check discard" begin
192193 alg = NUTS (100 , 0.8 ; adtype= adbackend)
193194
194- c1 = sample (rng, gdemo_default, alg, 500 , discard_adapt= true )
195- c2 = sample (rng, gdemo_default, alg, 500 , discard_adapt= false )
195+ c1 = sample (rng, gdemo_default, alg, 500 ; discard_adapt= true )
196+ c2 = sample (rng, gdemo_default, alg, 500 ; discard_adapt= false )
196197
197198 @test size (c1, 1 ) == 500
198199 @test size (c2, 1 ) == 500
@@ -210,20 +211,20 @@ using Turing
210211 # https://github.com/TuringLang/DynamicPPL.jl/issues/27
211212 @model function mwe1 (:: Type{T} = Float64) where {T<: Real }
212213 m = Matrix {T} (undef, 2 , 3 )
213- m .~ MvNormal (zeros (2 ), I)
214+ return m .~ MvNormal (zeros (2 ), I)
214215 end
215216 @test sample (rng, mwe1 (), HMC (0.2 , 4 ; adtype= adbackend), 1_000 ) isa Chains
216217
217218 @model function mwe2 (:: Type{T} = Matrix{Float64}) where {T}
218219 m = T (undef, 2 , 3 )
219- m .~ MvNormal (zeros (2 ), I)
220+ return m .~ MvNormal (zeros (2 ), I)
220221 end
221222 @test sample (rng, mwe2 (), HMC (0.2 , 4 ; adtype= adbackend), 1_000 ) isa Chains
222223
223224 # https://github.com/TuringLang/Turing.jl/issues/1308
224225 @model function mwe3 (:: Type{T} = Array{Float64}) where {T}
225226 m = T (undef, 2 , 3 )
226- m .~ MvNormal (zeros (2 ), I)
227+ return m .~ MvNormal (zeros (2 ), I)
227228 end
228229 @test sample (rng, mwe3 (), HMC (0.2 , 4 ; adtype= adbackend), 1_000 ) isa Chains
229230 end
@@ -241,13 +242,17 @@ using Turing
241242 @model function demo_hmc_prior ()
242243 # NOTE: Used to use `InverseGamma(2, 3)` but this has infinite variance
243244 # which means that it's _very_ difficult to find a good tolerance in the test below:)
244- s ~ truncated (Normal (3 , 1 ), lower= 0 )
245- m ~ Normal (0 , sqrt (s))
245+ s ~ truncated (Normal (3 , 1 ); lower= 0 )
246+ return m ~ Normal (0 , sqrt (s))
246247 end
247248 alg = NUTS (1000 , 0.8 ; adtype= adbackend)
248- gdemo_default_prior = DynamicPPL. contextualize (demo_hmc_prior (), DynamicPPL. PriorContext ())
249+ gdemo_default_prior = DynamicPPL. contextualize (
250+ demo_hmc_prior (), DynamicPPL. PriorContext ()
251+ )
249252 chain = sample (gdemo_default_prior, alg, 10_000 ; initial_params= [3.0 , 0.0 ])
250- check_numerical (chain, [:s , :m ], [mean (truncated (Normal (3 , 1 ); lower= 0 )), 0 ], atol= 0.2 )
253+ check_numerical (
254+ chain, [:s , :m ], [mean (truncated (Normal (3 , 1 ); lower= 0 )), 0 ]; atol= 0.2
255+ )
251256 end
252257
253258 @testset " warning for difficult init params" begin
@@ -262,7 +267,7 @@ using Turing
262267 @test_logs (
263268 :warn ,
264269 " failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword" ,
265- ) (:info ,) match_mode= :any begin
270+ ) (:info ,) match_mode = :any begin
266271 sample (demo_warn_initial_params (), NUTS (; adtype= adbackend), 5 )
267272 end
268273 end
@@ -271,7 +276,7 @@ using Turing
271276 @model function vector_of_dirichlet (:: Type{TV} = Vector{Float64}) where {TV}
272277 xs = Vector {TV} (undef, 2 )
273278 xs[1 ] ~ Dirichlet (ones (5 ))
274- xs[2 ] ~ Dirichlet (ones (5 ))
279+ return xs[2 ] ~ Dirichlet (ones (5 ))
275280 end
276281 model = vector_of_dirichlet ()
277282 chain = sample (model, NUTS (), 1000 )
@@ -296,15 +301,10 @@ using Turing
296301 end
297302 end
298303
299- model = buggy_model ();
300- num_samples = 1_000 ;
304+ model = buggy_model ()
305+ num_samples = 1_000
301306
302- chain = sample (
303- model,
304- NUTS (),
305- num_samples;
306- initial_params= [0.5 , 1.75 , 1.0 ]
307- )
307+ chain = sample (model, NUTS (), num_samples; initial_params= [0.5 , 1.75 , 1.0 ])
308308 chain_prior = sample (model, Prior (), num_samples)
309309
310310 # Extract the `x` like this because running `generated_quantities` was how
0 commit comments