Skip to content

Commit b60141e

Browse files
committed
fix the overflow issue in probability
1 parent 83ee4e7 commit b60141e

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/Core.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,16 @@ function log_probability(tn::TensorNetworkModel, config::Union{Dict, AbstractVec
203203
assign = config isa AbstractVector ? Dict(zip(get_vars(tn), config)) : config
204204
return sum(x -> log(x[2][(getindex.(Ref(assign), x[1]) .+ 1)...]), zip(getixsv(tn.code), tn.tensors))
205205
end
206+
"""
207+
$(TYPEDSIGNATURES)
208+
209+
Evaluate the log probability (or partition function).
210+
It is the logged version of [`probability`](@ref), which is less likely to overflow.
211+
"""
212+
function log_probability(tn::TensorNetworkModel; usecuda = false)::AbstractArray
213+
res = probability(tn; usecuda, rescale=true)
214+
return asarray(res.log_factor .+ log.(res.normalized_value), res.normalized_value)
215+
end
206216

207217
"""
208218
$(TYPEDSIGNATURES)

test/pr.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,19 @@ using TensorInference
2222
for (id, problem) in problems[problem_set_name]
2323
@info "Testing: $(problem_set_name)_$id"
2424
tn = TensorNetworkModel(read_model(problem); optimizer, evidence=read_evidence(problem))
25-
solution = probability(tn) |> first |> log10
25+
solution = log_probability(tn) / log(10) |> first
2626
@test isapprox(solution, read_solution(problem); atol = 1e-3)
2727
end
2828
end
2929
end
3030
end
31+
32+
@testset "issue 77" begin
33+
problems = dataset_from_artifact("uai2014")["PR"]
34+
problem_set_name = "Alchemy"
35+
optimizer = TreeSA(ntrials = 1, niters = 5, βs = 0.1:0.1:100)
36+
id, problem = problems[problem_set_name] |> first
37+
tn = TensorNetworkModel(read_model(problem); optimizer, evidence=read_evidence(problem))
38+
solution = log_probability(tn) / log(10) |> first
39+
@test isapprox(solution, read_solution(problem); atol=1e-3)
40+
end

0 commit comments

Comments
 (0)