|
1 | 1 | module MCMCUtilitiesTests |
2 | 2 |
|
3 | 3 | using ..Models: gdemo_default |
4 | | -using Distributions: Normal, sample, truncated |
5 | | -using LinearAlgebra: I, vec |
6 | | -using Random: Random |
7 | | -using Random: MersenneTwister |
8 | 4 | using Test: @test, @testset |
9 | 5 | using Turing |
10 | 6 |
|
11 | | -@testset "predict" begin |
12 | | - Random.seed!(100) |
13 | | - |
14 | | - @model function linear_reg(x, y, σ=0.1) |
15 | | - β ~ Normal(0, 1) |
16 | | - |
17 | | - for i in eachindex(y) |
18 | | - y[i] ~ Normal(β * x[i], σ) |
19 | | - end |
20 | | - end |
21 | | - |
22 | | - @model function linear_reg_vec(x, y, σ=0.1) |
23 | | - β ~ Normal(0, 1) |
24 | | - return y ~ MvNormal(β .* x, σ^2 * I) |
25 | | - end |
26 | | - |
27 | | - f(x) = 2 * x + 0.1 * randn() |
28 | | - |
29 | | - Δ = 0.1 |
30 | | - xs_train = 0:Δ:10 |
31 | | - ys_train = f.(xs_train) |
32 | | - xs_test = [10 + Δ, 10 + 2 * Δ] |
33 | | - ys_test = f.(xs_test) |
34 | | - |
35 | | - # Infer |
36 | | - m_lin_reg = linear_reg(xs_train, ys_train) |
37 | | - chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), 200) |
38 | | - |
39 | | - # Predict on two last indices |
40 | | - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) |
41 | | - predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
42 | | - |
43 | | - ys_pred = vec(mean(Array(group(predictions, :y)); dims=1)) |
44 | | - |
45 | | - @test sum(abs2, ys_test - ys_pred) ≤ 0.1 |
46 | | - |
47 | | - # Ensure that `rng` is respected |
48 | | - predictions1 = let rng = MersenneTwister(42) |
49 | | - predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) |
50 | | - end |
51 | | - predictions2 = let rng = MersenneTwister(42) |
52 | | - predict(rng, m_lin_reg_test, chain_lin_reg[1:2]) |
53 | | - end |
54 | | - @test all(Array(predictions1) .== Array(predictions2)) |
55 | | - |
56 | | - # Predict on two last indices for vectorized |
57 | | - m_lin_reg_test = linear_reg_vec(xs_test, missing) |
58 | | - predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
59 | | - ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1)) |
60 | | - |
61 | | - @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 |
62 | | - |
63 | | - # Multiple chains |
64 | | - chain_lin_reg = sample(m_lin_reg, NUTS(100, 0.65), MCMCThreads(), 200, 2) |
65 | | - m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test))) |
66 | | - predictions = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
67 | | - |
68 | | - @test size(chain_lin_reg, 3) == size(predictions, 3) |
69 | | - |
70 | | - for chain_idx in MCMCChains.chains(chain_lin_reg) |
71 | | - ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1)) |
72 | | - @test sum(abs2, ys_test - ys_pred) ≤ 0.1 |
73 | | - end |
74 | | - |
75 | | - # Predict on two last indices for vectorized |
76 | | - m_lin_reg_test = linear_reg_vec(xs_test, missing) |
77 | | - predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) |
78 | | - |
79 | | - for chain_idx in MCMCChains.chains(chain_lin_reg) |
80 | | - ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1)) |
81 | | - @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 |
82 | | - end |
83 | | - |
84 | | - # https://github.com/TuringLang/Turing.jl/issues/1352 |
85 | | - @model function simple_linear1(x, y) |
86 | | - intercept ~ Normal(0, 1) |
87 | | - coef ~ MvNormal(zeros(2), I) |
88 | | - coef = reshape(coef, 1, size(x, 1)) |
89 | | - |
90 | | - mu = vec(intercept .+ coef * x) |
91 | | - error ~ truncated(Normal(0, 1), 0, Inf) |
92 | | - return y ~ MvNormal(mu, error^2 * I) |
93 | | - end |
94 | | - |
95 | | - @model function simple_linear2(x, y) |
96 | | - intercept ~ Normal(0, 1) |
97 | | - coef ~ filldist(Normal(0, 1), 2) |
98 | | - coef = reshape(coef, 1, size(x, 1)) |
99 | | - |
100 | | - mu = vec(intercept .+ coef * x) |
101 | | - error ~ truncated(Normal(0, 1), 0, Inf) |
102 | | - return y ~ MvNormal(mu, error^2 * I) |
103 | | - end |
104 | | - |
105 | | - @model function simple_linear3(x, y) |
106 | | - intercept ~ Normal(0, 1) |
107 | | - coef = Vector(undef, 2) |
108 | | - for i in axes(coef, 1) |
109 | | - coef[i] ~ Normal(0, 1) |
110 | | - end |
111 | | - coef = reshape(coef, 1, size(x, 1)) |
112 | | - |
113 | | - mu = vec(intercept .+ coef * x) |
114 | | - error ~ truncated(Normal(0, 1), 0, Inf) |
115 | | - return y ~ MvNormal(mu, error^2 * I) |
116 | | - end |
117 | | - |
118 | | - @model function simple_linear4(x, y) |
119 | | - intercept ~ Normal(0, 1) |
120 | | - coef1 ~ Normal(0, 1) |
121 | | - coef2 ~ Normal(0, 1) |
122 | | - coef = [coef1, coef2] |
123 | | - coef = reshape(coef, 1, size(x, 1)) |
124 | | - |
125 | | - mu = vec(intercept .+ coef * x) |
126 | | - error ~ truncated(Normal(0, 1), 0, Inf) |
127 | | - return y ~ MvNormal(mu, error^2 * I) |
128 | | - end |
129 | | - |
130 | | - # Some data |
131 | | - x = randn(2, 100) |
132 | | - y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)] |
133 | | - |
134 | | - for model in [simple_linear1, simple_linear2, simple_linear3, simple_linear4] |
135 | | - m = model(x, y) |
136 | | - chain = sample(m, NUTS(), 100) |
137 | | - chain_predict = predict(model(x, missing), chain) |
138 | | - mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)] |
139 | | - @test mean(abs2, mean_prediction - y) ≤ 1e-3 |
140 | | - end |
141 | | -end |
142 | | - |
143 | 7 | @testset "Timer" begin |
144 | 8 | chain = sample(gdemo_default, MH(), 1000) |
145 | 9 |
|
|
0 commit comments