@@ -22,62 +22,23 @@ $(TYPEDEF)
2222* `nclique` is the number of cliques,
2323* `cards` is a vector of cardinalities for variables,
2424* `factors` is a vector of factors,
25-
26- * `obsvars` is a vector of observed variables,
27- * `obsvals` is a vector of observed values,
28- * `queryvars` is a vector of query variables,
29- * `reference_solution` is a vector with the reference solution.
3025"""
3126struct UAIInstance{ET, FT <: Factor{ET} }
3227 nvars:: Int
3328 nclique:: Int
3429 cards:: Vector{Int}
3530 factors:: Vector{FT}
36-
37- obsvars:: Vector{Int}
38- obsvals:: Vector{Int}
39- queryvars:: Vector{Int}
40- reference_solution
4131end
4232
4333Base. show (io:: IO , :: MIME"text/plain" , uai:: UAIInstance ) = Base. show (io, uai)
4434function Base. show (io:: IO , uai:: UAIInstance )
4535 println (io, " UAIInstance(nvars = $(uai. nvars) , nclique = $(uai. nclique) )" )
4636 println (io, " variables :" )
47- for (var, card) in zip (1 : uai. nvars, uai. cards)
48- println (io, string_var (" $var of size $card " , var ∈ uai. queryvars, Dict (zip (uai. obsvars, uai. obsvals))))
49- end
5037 println (io, " factors : " )
51- for f in uai. factors
52- println (io, " $(summary (f)) " )
38+ for (k, f) in enumerate (uai. factors)
39+ print (io, " $(summary (f)) " )
40+ k == length (uai. factors) || println (io)
5341 end
54- print (io, " reference_solution : $(uai. reference_solution) " )
55- end
56-
57- """
58- $TYPEDSIGNATURES
59-
60- Set the evidence of an UAI instance.
61- """
62- function set_evidence! (uai:: UAIInstance , pairs:: Pair{Int} ...)
63- empty! (uai. obsvars)
64- empty! (uai. obsvals)
65- for (var, val) in pairs
66- push! (uai. obsvars, var)
67- push! (uai. obsvals, val)
68- end
69- return uai
70- end
71-
72- """
73- $TYPEDSIGNATURES
74-
75- Set the query variables of an UAI instance.
76- """
77- function set_query! (uai:: UAIInstance , vars:: AbstractVector{Int} )
78- empty! (uai. queryvars)
79- append! (uai. queryvars, vars)
80- return uai
8142end
8243
8344"""
@@ -89,32 +50,32 @@ Probabilistic modeling with a tensor network.
8950* `vars` is the degree of freedoms in the tensor network.
9051* `code` is the tensor network contraction pattern.
9152* `tensors` is the tensors fed into the tensor network.
92- * `fixedvertices ` is a dictionary to specifiy degree of freedoms fixed to certain values.
53+ * `evidence ` is a dictionary to specifiy degree of freedoms fixed to certain values.
9354"""
9455struct TensorNetworkModel{LT, ET, MT <: AbstractArray }
9556 vars:: Vector{LT}
9657 code:: ET
9758 tensors:: Vector{MT}
98- fixedvertices :: Dict{LT, Int}
59+ evidence :: Dict{LT, Int}
9960end
10061
10162function Base. show (io:: IO , tn:: TensorNetworkModel )
10263 open = getiyv (tn. code)
103- variables = join ([string_var (var, open, tn. fixedvertices ) for var in tn. vars], " , " )
64+ variables = join ([string_var (var, open, tn. evidence ) for var in tn. vars], " , " )
10465 tc, sc, rw = contraction_complexity (tn)
10566 println (io, " $(typeof (tn)) " )
10667 println (io, " variables: $variables " )
10768 print_tcscrw (io, tc, sc, rw)
10869end
10970Base. show (io:: IO , :: MIME"text/plain" , tn:: TensorNetworkModel ) = Base. show (io, tn)
11071
111- function string_var (var, open, fixedvertices )
112- if var ∈ open && haskey (fixedvertices , var)
113- " $var (open, fixed to $(fixedvertices [var]) )"
72+ function string_var (var, open, evidence )
73+ if var ∈ open && haskey (evidence , var)
74+ " $var (open, fixed to $(evidence [var]) )"
11475 elseif var ∈ open
11576 " $var (open)"
116- elseif haskey (fixedvertices , var)
117- " $var (evidence → $(fixedvertices [var]) )"
77+ elseif haskey (evidence , var)
78+ " $var (evidence → $(evidence [var]) )"
11879 else
11980 " $var "
12081 end
@@ -129,19 +90,17 @@ $(TYPEDSIGNATURES)
12990"""
13091function TensorNetworkModel (
13192 instance:: UAIInstance ;
132- openvertices = (),
93+ openvars = (),
94+ evidence = Dict {Int,Int} (),
13395 optimizer = GreedyMethod (),
13496 simplifier = nothing
13597):: TensorNetworkModel
136- if ! isempty (instance. queryvars)
137- @warn " The `queryvars` field of the input `UAIInstance` instance is designed for the `MMAPModel`, which is not respected by `TensorNetworkModel`. Got non-empty value: $(uai. queryvars) "
138- end
13998 return TensorNetworkModel (
14099 1 : (instance. nvars),
141100 instance. cards,
142101 instance. factors;
143- openvertices ,
144- fixedvertices = Dict ( zip (instance . obsvars, instance . obsvals)) ,
102+ openvars ,
103+ evidence ,
145104 optimizer,
146105 simplifier
147106 )
@@ -154,18 +113,18 @@ function TensorNetworkModel(
154113 vars:: AbstractVector{LT} ,
155114 cards:: AbstractVector{Int} ,
156115 factors:: Vector{<:Factor{T}} ;
157- openvertices = (),
158- fixedvertices = Dict {LT, Int} (),
116+ openvars = (),
117+ evidence = Dict {LT, Int} (),
159118 optimizer = GreedyMethod (),
160119 simplifier = nothing
161120):: TensorNetworkModel where {T, LT}
162121 # The 1st argument of `EinCode` is a vector of vector of labels for specifying the input tensors,
163122 # The 2nd argument of `EinCode` is a vector of labels for specifying the output tensor,
164123 # e.g.
165124 # `EinCode([[1, 2], [2, 3]], [1, 3])` is the EinCode for matrix multiplication.
166- rawcode = EinCode ([[[var] for var in vars]. .. , [[factor. vars... ] for factor in factors]. .. ], collect (LT, openvertices )) # labels for vertex tensors (unity tensors) and edge tensors
125+ rawcode = EinCode ([[[var] for var in vars]. .. , [[factor. vars... ] for factor in factors]. .. ], collect (LT, openvars )) # labels for vertex tensors (unity tensors) and edge tensors
167126 tensors = Array{T}[[ones (T, cards[i]) for i in 1 : length (vars)]. .. , [t. vals for t in factors]. .. ]
168- return TensorNetworkModel (collect (LT, vars), rawcode, tensors; fixedvertices , optimizer, simplifier)
127+ return TensorNetworkModel (collect (LT, vars), rawcode, tensors; evidence , optimizer, simplifier)
169128end
170129
171130"""
@@ -175,7 +134,7 @@ function TensorNetworkModel(
175134 vars:: AbstractVector{LT} ,
176135 rawcode:: EinCode ,
177136 tensors:: Vector{<:AbstractArray} ;
178- fixedvertices = Dict {LT, Int} (),
137+ evidence = Dict {LT, Int} (),
179138 optimizer = GreedyMethod (),
180139 simplifier = nothing
181140):: TensorNetworkModel where {LT}
@@ -185,7 +144,7 @@ function TensorNetworkModel(
185144 # The 3rd and 4th arguments are the optimizer and simplifier that configures which algorithm to use and simplify.
186145 size_dict = OMEinsum. get_size_dict (getixsv (rawcode), tensors)
187146 code = optimize_code (rawcode, size_dict, optimizer, simplifier)
188- TensorNetworkModel (collect (LT, vars), code, tensors, fixedvertices )
147+ TensorNetworkModel (collect (LT, vars), code, tensors, evidence )
189148end
190149
191150"""
@@ -202,10 +161,10 @@ Get the cardinalities of variables in this tensor network.
202161"""
203162function get_cards (tn:: TensorNetworkModel ; fixedisone = false ):: Vector
204163 vars = get_vars (tn)
205- [fixedisone && haskey (tn. fixedvertices , vars[k]) ? 1 : length (tn. tensors[k]) for k in 1 : length (vars)]
164+ [fixedisone && haskey (tn. evidence , vars[k]) ? 1 : length (tn. tensors[k]) for k in 1 : length (vars)]
206165end
207166
208- chfixedvertices (tn:: TensorNetworkModel , fixedvertices ) = TensorNetworkModel (tn. vars, tn. code, tn. tensors, fixedvertices )
167+ chevidence (tn:: TensorNetworkModel , evidence ) = TensorNetworkModel (tn. vars, tn. code, tn. tensors, evidence )
209168
210169"""
211170$(TYPEDSIGNATURES)
221180$(TYPEDSIGNATURES)
222181
223182Contract the tensor network and return a probability array with its rank specified in the contraction code `tn.code`.
224- The returned array may not be l1-normalized even if the total probability is l1-normalized, because the evidence `tn.fixedvertices ` may not be empty.
183+ The returned array may not be l1-normalized even if the total probability is l1-normalized, because the evidence `tn.evidence ` may not be empty.
225184"""
226185function probability (tn:: TensorNetworkModel ; usecuda = false , rescale = true ):: AbstractArray
227186 return tn. code (adapt_tensors (tn; usecuda, rescale)... )
0 commit comments