|
1 | | -module TestExtUtils |
2 | | - |
3 | | -################################################### |
4 | | -# These used to be in DPPL/src/test_utils.jl ###### |
5 | | -################################################### |
| 1 | +module TestUtils |
6 | 2 |
|
7 | 3 | using AbstractMCMC |
8 | 4 | using DynamicPPL |
@@ -1101,123 +1097,4 @@ function DynamicPPL.dot_tilde_observe( |
1101 | 1097 | return logp * context.mod, vi |
1102 | 1098 | end |
1103 | 1099 |
|
1104 | | - |
1105 | | - |
1106 | | -################################################### |
1107 | | -# These used to be in DPPL/test/test_util.jl ###### |
1108 | | -################################################### |
1109 | | - |
1110 | | -# default model |
1111 | | -@model function gdemo_d() |
1112 | | - s ~ InverseGamma(2, 3) |
1113 | | - m ~ Normal(0, sqrt(s)) |
1114 | | - 1.5 ~ Normal(m, sqrt(s)) |
1115 | | - 2.0 ~ Normal(m, sqrt(s)) |
1116 | | - return s, m |
1117 | | -end |
1118 | | -const gdemo_default = gdemo_d() |
1119 | | - |
1120 | | -function test_model_ad(model, logp_manual) |
1121 | | - vi = VarInfo(model) |
1122 | | - x = DynamicPPL.getall(vi) |
1123 | | - |
1124 | | - # Log probabilities using the model. |
1125 | | - ℓ = DynamicPPL.LogDensityFunction(model, vi) |
1126 | | - logp_model = Base.Fix1(LogDensityProblems.logdensity, ℓ) |
1127 | | - |
1128 | | - # Check that both functions return the same values. |
1129 | | - lp = logp_manual(x) |
1130 | | - @test logp_model(x) ≈ lp |
1131 | | - |
1132 | | - # Gradients based on the manual implementation. |
1133 | | - grad = ForwardDiff.gradient(logp_manual, x) |
1134 | | - |
1135 | | - y, back = Tracker.forward(logp_manual, x) |
1136 | | - @test Tracker.data(y) ≈ lp |
1137 | | - @test Tracker.data(back(1)[1]) ≈ grad |
1138 | | - |
1139 | | - y, back = Zygote.pullback(logp_manual, x) |
1140 | | - @test y ≈ lp |
1141 | | - @test back(1)[1] ≈ grad |
1142 | | - |
1143 | | - # Gradients based on the model. |
1144 | | - @test ForwardDiff.gradient(logp_model, x) ≈ grad |
1145 | | - |
1146 | | - y, back = Tracker.forward(logp_model, x) |
1147 | | - @test Tracker.data(y) ≈ lp |
1148 | | - @test Tracker.data(back(1)[1]) ≈ grad |
1149 | | - |
1150 | | - y, back = Zygote.pullback(logp_model, x) |
1151 | | - @test y ≈ lp |
1152 | | - @test back(1)[1] ≈ grad |
1153 | | -end |
1154 | | - |
1155 | | -""" |
1156 | | - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) |
1157 | | -
|
1158 | | -Test `setval!` on `model` and `chain`. |
1159 | | -
|
1160 | | -Worth noting that this only supports models containing symbols of the forms |
1161 | | -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. |
1162 | | -""" |
1163 | | -function test_setval!(model, chain; sample_idx=1, chain_idx=1) |
1164 | | - var_info = VarInfo(model) |
1165 | | - spl = SampleFromPrior() |
1166 | | - θ_old = var_info[spl] |
1167 | | - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) |
1168 | | - θ_new = var_info[spl] |
1169 | | - @test θ_old != θ_new |
1170 | | - vals = DynamicPPL.values_as(var_info, OrderedDict) |
1171 | | - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) |
1172 | | - for (n, v) in mapreduce(collect, vcat, iters) |
1173 | | - n = string(n) |
1174 | | - if Symbol(n) ∉ keys(chain) |
1175 | | - # Assume it's a group |
1176 | | - chain_val = vec( |
1177 | | - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] |
1178 | | - ) |
1179 | | - v_true = vec(v) |
1180 | | - else |
1181 | | - chain_val = chain[sample_idx, n, chain_idx] |
1182 | | - v_true = v |
1183 | | - end |
1184 | | - |
1185 | | - @test v_true == chain_val |
1186 | | - end |
1187 | | -end |
1188 | | - |
1189 | | -""" |
1190 | | - short_varinfo_name(vi::AbstractVarInfo) |
1191 | | -
|
1192 | | -Return string representing a short description of `vi`. |
1193 | | -""" |
1194 | | -short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = |
1195 | | - "threadsafe($(short_varinfo_name(vi.varinfo)))" |
1196 | | -function short_varinfo_name(vi::TypedVarInfo) |
1197 | | - DynamicPPL.has_varnamedvector(vi) && return "TypedVarInfo with VarNamedVector" |
1198 | | - return "TypedVarInfo" |
1199 | | -end |
1200 | | -short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo" |
1201 | | -short_varinfo_name(::DynamicPPL.VectorVarInfo) = "VectorVarInfo" |
1202 | | -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" |
1203 | | -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" |
1204 | | -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) |
1205 | | - return "SimpleVarInfo{<:VarNamedVector}" |
1206 | | -end |
1207 | | - |
1208 | | -# convenient functions for testing model.jl |
1209 | | -# function to modify the representation of values based on their length |
1210 | | -function modify_value_representation(nt::NamedTuple) |
1211 | | - modified_nt = NamedTuple() |
1212 | | - for (key, value) in zip(keys(nt), values(nt)) |
1213 | | - if length(value) == 1 # Scalar value |
1214 | | - modified_value = value[1] |
1215 | | - else # Non-scalar value |
1216 | | - modified_value = value |
1217 | | - end |
1218 | | - modified_nt = merge(modified_nt, (key => modified_value,)) |
1219 | | - end |
1220 | | - return modified_nt |
1221 | 1100 | end |
1222 | | - |
1223 | | -end # module TestExtUtils |
0 commit comments