@@ -28,17 +28,44 @@ function subset(samples::Samples{L}, labels::AbstractVector{L}) where L
2828 return view (samples. samples, idx, :)
2929end
3030
"""
    eliminate_dimensions(x, ix, el)

Slice tensor `x`, whose axes are labelled by `ix`, by pinning every axis whose
label appears in `el.first` to the single index given by the corresponding
(0-based) entry of `el.second`. Pinned axes are kept with length 1, so the
result has the same rank as `x`.
"""
function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractVector}) where {T, N, L}
    @assert length(ix) == N
    return x[eliminated_selector(size(x), ix, el.first, el.second)...]
end
# Size of `x` after pinning the axes whose labels appear in `labels`:
# pinned axes collapse to length 1, the others keep their original extent.
function eliminated_size(size0, ix, labels)
    @assert length(size0) == length(ix)
    return ntuple(i -> ix[i] in labels ? 1 : size0[i], length(ix))
end
# Indexing ranges realizing `eliminated_size`: a label found in `labels` pins
# its axis to the single 1-based index `config[...] + 1` (configs are 0-based),
# while every other axis takes its full range.
function eliminated_selector(size0, ix, labels, config)
    return ntuple(length(ix)) do i
        if ix[i] in labels
            pos = config[findfirst(isequal(ix[i]), labels)] + 1
            pos:pos
        else
            1:size0[i]
        end
    end
end
# Batched elimination for a tensor that does NOT yet carry a batch axis.
# `el.second` is a (nvars, nbatch) matrix of configurations; a fresh trailing
# batch axis of length `nbatch` is appended to the result and `batch_label`
# is pushed onto `ix` (mutating it, hence the `!`).
function eliminate_dimensions_addbatch!(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractMatrix}, batch_label::L) where {T, N, L}
    nbatch = size(el.second, 2)
    @assert length(ix) == N
    res = similar(x, (eliminated_size(size(x), ix, el.first)..., nbatch))
    for ibatch in 1:nbatch
        # each batch slice is the plain (unbatched) elimination at that config
        selectdim(res, N + 1, ibatch) .= eliminate_dimensions(x, ix, el.first => view(el.second, :, ibatch))
    end
    push!(ix, batch_label)
    return res
end
# Batched elimination for a tensor whose LAST axis is already the batch axis
# (the batch label is assumed to be the last entry of `ix`). Each batch slice
# is eliminated with its own configuration column of `el.second`.
function eliminate_dimensions_withbatch(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractMatrix}) where {T, N, L}
    nbatch = size(el.second, 2)
    @assert length(ix) == N && size(x, N) == nbatch
    res = similar(x, eliminated_size(size(x), ix, el.first))
    for ibatch in 1:nbatch
        selectdim(res, N, ibatch) .= eliminate_dimensions(selectdim(x, N, ibatch), ix[1:end-1], el.first => view(el.second, :, ibatch))
    end
    return res
end
4370
4471"""
@@ -72,78 +99,96 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
7299 end
73100 samples = Samples (configs, queryvars)
74101 # back-propagate
75- env = copy (cache. content)
102+ env = similar (cache. content, ( size (cache . content) ... , n)) # batched env
76103 fill! (env, one (eltype (env)))
77- generate_samples! (tn. code, cache, env, samples, copy (samples. labels), size_dict) # note: `copy` is necessary
104+ batch_label = _newindex (OMEinsum. uniquelabels (tn. code))
105+ code = deepcopy (tn. code)
106+ iy_env = [OMEinsum. getiyv (code)... , batch_label]
107+ size_dict[batch_label] = n
108+ generate_samples! (code, cache, iy_env, env, samples, copy (samples. labels), batch_label, size_dict) # note: `copy` is necessary
78109 # set evidence variables
79- for (k, v) in tn. evidence
80- idx = findfirst (== (k), samples . labels )
110+ for (k, v) in setdiff ( tn. evidence, queryvars)
111+ idx = findfirst (== (k), queryvars )
81112 samples. samples[idx, :] .= v
82113 end
83114 return samples
84115end
# Produce a fresh label guaranteed not to clash with the existing ones:
# numeric/char labels get `maximum + 1`, symbolic labels get a gensym.
_newindex(labels::AbstractVector{<:Union{Int, Char}}) = maximum(labels) + 1
_newindex(::AbstractVector{Symbol}) = gensym(:batch)
# Wrap a probability tensor as StatsBase `Weights`. Complex-valued inputs are
# accepted only when every imaginary part is numerically negligible; the real
# parts are then used as the weights.
_Weights(x::AbstractVector{<:Real}) = Weights(x)
function _Weights(x::AbstractArray{<:Complex})
    @assert all(e -> abs(imag(e)) < max(100 * eps(abs(e)), 1e-8), x) "Complex probability encountered: $x"
    return Weights(real.(x))
end
90123
# Entry point for `SlicedEinsum` contraction orders. Slicing itself is not
# supported by the cached back-propagation, so we warn and delegate to the
# inner `NestedEinsum` code.
#
# NOTE: `iy_env` is typed `Vector{L}` (not `Vector{Int}`) so that label types
# such as `Symbol` or `Char` — which `_newindex` explicitly supports — also
# dispatch here instead of raising a `MethodError`.
function generate_samples!(se::SlicedEinsum, cache::CacheTree{T}, iy_env::Vector{L}, env::AbstractArray{T}, samples::Samples{L}, pool, batch_label::L, size_dict::Dict{L}) where {T, L}
    # slicing is not supported yet.
    if length(se.slicing) != 0
        @warn "Slicing is not supported for caching, got nslices = $(length(se.slicing))! Fallback to `NestedEinsum`."
    end
    return generate_samples!(se.eins, cache, iy_env, env, samples, pool, batch_label, size_dict)
end
98131
# Core recursive sampler. `pool` holds the labels that have not been sampled
# (eliminated) yet; `env` is the batched environment tensor of the current
# node, whose axis labels are `iy_env` (last axis = `batch_label`).
function generate_samples!(code::DynamicNestedEinsum, cache::CacheTree{T}, iy_env::Vector{L}, env::AbstractArray{T}, samples::Samples{L}, pool::Vector{L}, batch_label::L, size_dict::Dict{L}) where {T, L}
    @assert length(iy_env) == ndims(env)
    if !(OMEinsum.isleaf(code))
        ixs, iy = getixsv(code.eins), getiyv(code.eins)
        for (subcode, child, ix) in zip(code.args, cache.children, ixs)
            # subenv for the current child, use it to sample and update its cache
            siblings = filter(x -> x !== child, cache.children)
            siblings_ixs = filter(x -> x !== ix, ixs)
            # keep the batch axis on the child's environment
            iy_subenv = batch_label ∈ ix ? ix : [ix..., batch_label]
            envcode = optimize_code(EinCode([siblings_ixs..., iy_env], iy_subenv), size_dict, GreedyMethod(; nrepeat=1))
            subenv = einsum(envcode, (getfield.(siblings, :content)..., env), size_dict)

            # generate samples
            sample_vars = ix ∩ pool
            if !isempty(sample_vars)
                # joint marginal over `sample_vars`, one slice per batch sample
                probabilities = einsum(DynamicEinCode([ix, iy_subenv], [sample_vars..., batch_label]), (child.content, subenv), size_dict)
                for ibatch in axes(probabilities, ndims(probabilities))
                    update_samples!(samples.labels, samples[ibatch], sample_vars, selectdim(probabilities, ndims(probabilities), ibatch))
                end
                setdiff!(pool, sample_vars)

                # eliminate the sampled variables
                setindex!.(Ref(size_dict), 1, sample_vars)
                subsamples = subset(samples, sample_vars)
                udpate_cache_tree!(code, cache, sample_vars => subsamples, batch_label, size_dict)
                subenv = _eliminate!(subenv, ix, sample_vars => subsamples, batch_label)
            end

            # recurse
            generate_samples!(subcode, child, iy_subenv, subenv, samples, pool, batch_label, size_dict)
        end
    end
end
129166
# Dispatch elimination on whether `x` already carries the batch axis,
# i.e. whether `batch_label` is among its labels `ix`. The addbatch branch
# mutates `ix` (appends `batch_label`), hence the `!`.
function _eliminate!(x, ix, el, batch_label)
    return batch_label in ix ?
        eliminate_dimensions_withbatch(x, ix, el) :
        eliminate_dimensions_addbatch!(x, ix, el, batch_label)
end
174+
# `probabilities` is a joint probability tensor with one axis per variable in
# `vars`; draw a single configuration from it and write the result (0-based)
# into `sample` at the positions of `vars` within `labels`.
function update_samples!(labels, sample, vars::AbstractVector{L}, probabilities::AbstractArray{T, N}) where {L, T, N}
    @assert length(vars) == N
    candidates = CartesianIndices(probabilities)
    target_locs = idx4labels(labels, vars)
    drawn = StatsBase.sample(candidates, _Weights(vec(probabilities)))
    sample[target_locs] .= drawn.I .- 1
end
138183
139- function udpate_cache_tree! (ne:: NestedEinsum , cache:: CacheTree{T} , el:: Pair{<:AbstractVector{L}} , size_dict:: Dict{L} ) where {T, L}
184+ function udpate_cache_tree! (ne:: NestedEinsum , cache:: CacheTree{T} , el:: Pair{<:AbstractVector{L}} , batch_label :: L , size_dict:: Dict{L} ) where {T, L}
140185 OMEinsum. isleaf (ne) && return
141186 updated = false
142187 for (subcode, child, ix) in zip (ne. args, cache. children, getixsv (ne. eins))
143188 if any (x-> x ∈ el. first, ix)
144189 updated = true
145- child. content = eliminate_dimensions (child. content, ix, el)
146- udpate_cache_tree! (subcode, child, el, size_dict)
190+ child. content = _eliminate! (child. content, ix, el, batch_label )
191+ udpate_cache_tree! (subcode, child, el, batch_label, size_dict)
147192 end
148193 end
149194 updated && (cache. content = einsum (ne. eins, (getfield .(cache. children, :content )... ,), size_dict))
0 commit comments