@@ -7,20 +7,10 @@ $TYPEDFIELDS
77
88The sampled configurations are stored in `samples`, which is a vector of vector.
99`labels` is a vector of variable names for labeling configurations.
10- The `setmask` is an boolean indicator to denote whether the sampling process of a variable is complete.
1110"""
1211struct Samples{L} <: AbstractVector {SubArray{Float64, 1 , Matrix{Float64}, Tuple{Base. Slice{Base. OneTo{Int64}}, Int64}, true }}
1312 samples:: Matrix{Int} # size is nvars × nsample
1413 labels:: Vector{L}
15- setmask:: BitVector
16- end
17- function set_eliminated! (samples:: Samples , eliminated_variables)
18- for var in eliminated_variables
19- loc = findfirst (== (var), samples. labels)
20- samples. setmask[loc] && error (" varaible `$var ` is already eliminated." )
21- samples. setmask[loc] = true
22- end
23- return samples
2414end
2515Base. getindex (s:: Samples , i:: Int ) = view (s. samples, :, i)
2616Base. length (s:: Samples ) = size (s. samples, 2 )
@@ -30,16 +20,12 @@ function Base.show(io::IO, s::Samples) # display with PrettyTables
3020 PrettyTables. pretty_table (io, s. samples' , header= s. labels)
3121end
3222num_samples (samples:: Samples ) = size (samples. samples, 2 )
33- eliminated_variables (samples:: Samples ) = samples. labels[samples. setmask]
34- is_eliminated (samples:: Samples{L} , var:: L ) where L = samples. setmask[findfirst (== (var), samples. labels)]
3523function idx4labels (totalset:: AbstractVector{L} , labels:: AbstractVector{L} ):: Vector{Int} where L
3624 map (v-> findfirst (== (v), totalset), labels)
3725end
38- idx4labels (samples:: Samples{L} , lb:: L ) where L = findfirst (== (lb), samples. labels)
3926function subset (samples:: Samples{L} , labels:: AbstractVector{L} ) where L
4027 idx = idx4labels (samples. labels, labels)
41- @assert all (i-> samples. setmask[i], idx)
42- return samples. samples[idx, :]
28+ return view (samples. samples, idx, :)
4329end
4430
4531function eliminate_dimensions (x:: AbstractArray{T, N} , ix:: AbstractVector{L} , el:: Pair{<:AbstractVector{L}} ) where {T, N, L}
@@ -55,40 +41,6 @@ function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el:
5541 return asarray (x[idx... ], x)
5642end
5743
58- function addbatch (samples:: Samples , eliminated_variables)
59- uniquelabels = unique! (vcat (ixs... , iy))
60- labelmap = Dict (zip (uniquelabels, 1 : length (uniquelabels)))
61- batchdim = length (labelmap) + 1
62- newnewixs = [Int[getindex .(Ref (labelmap), ix)... , batchdim] for ix in newixs]
63- newnewiy = Int[getindex .(Ref (labelmap), eliminated_variables)... , batchdim]
64- newnewxs = [get_slice (x, dimx, samples. samples[ixloc, :]) for (x, dimx, ixloc) in zip (xs, slice_xs_dim, ix_in_sample)]
65- end
66-
67- # type unstable
68- function get_slice (x:: AbstractArray{T} , slicedim, configs:: AbstractMatrix ) where T
69- outdim = setdiff (1 : ndims (x), slicedim)
70- res = similar (x, [size (x, d) for d in outdim]. .. , size (configs, 2 ))
71- return get_slice! (res, x, outdim, slicedim, configs)
72- end
73-
74- function get_slice! (res, x:: AbstractArray{T} , outdim, slicedim, configs:: AbstractMatrix ) where T
75- xstrides = strides (x)
76- @inbounds for ci in CartesianIndices (res)
77- idx = 1
78- # the output dimension part
79- for (dim, k) in zip (outdim, ci. I)
80- idx += (k- 1 ) * xstrides[dim]
81- end
82- # the sliced part
83- batchidx = ci. I[end ]
84- for (dim, k) in zip (slicedim, view (configs, :, batchidx))
85- idx += k * xstrides[dim]
86- end
87- res[ci] = x[idx]
88- end
89- return res
90- end
91-
9244"""
9345$(TYPEDSIGNATURES)
9446
@@ -108,15 +60,13 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
10860 cache = cached_einsum (tn. code, xs, size_dict)
10961 # initialize `y̅` as the initial batch of samples.
11062 iy = getiyv (tn. code)
111- setmask = falses (length (queryvars))
11263 idx = map (l-> findfirst (== (l), queryvars), iy ∩ queryvars)
113- setmask[idx] .= true
11464 indices = StatsBase. sample (CartesianIndices (size (cache. content)), _Weights (vec (cache. content)), n)
11565 configs = zeros (Int, length (queryvars), n)
11666 for i= 1 : n
11767 configs[idx, i] .= indices[i]. I .- 1
11868 end
119- samples = Samples (configs, queryvars, setmask )
69+ samples = Samples (configs, queryvars)
12070 # back-propagate
12171 env = copy (cache. content)
12272 fill! (env, one (eltype (env)))
@@ -157,12 +107,12 @@ function generate_samples!(code::NestedEinsum, cache::CacheTree{T}, env::Abstrac
157107 sample_vars = ix ∩ pool
158108 probabilities = einsum (DynamicEinCode ([ix, ix], sample_vars), (child. content, subenv), size_dict)
159109 update_samples! (samples, sample_vars, probabilities)
160- pool = setdiff (pool, sample_vars)
110+ setdiff! (pool, sample_vars)
161111
162112 # eliminate the sampled variables
163113 setindex! .(Ref (size_dict), 1 , sample_vars)
164114 subsamples = subset (samples, sample_vars)[:, 1 ]
165- udpate_cache_tree! (code, cache, sample_vars=> subsamples)
115+ udpate_cache_tree! (code, cache, sample_vars=> subsamples, size_dict )
166116 subenv = eliminate_dimensions (subenv, ix, sample_vars=> subsamples)
167117
168118 # recurse
@@ -178,13 +128,15 @@ function update_samples!(samples::Samples, vars::AbstractVector{L}, probabilitie
178128 eliminated_locs = idx4labels (samples. labels, vars)
179129 config = StatsBase. sample (totalset, _Weights (vec (probabilities)))
180130 samples. samples[eliminated_locs, 1 ] .= config. I .- 1
181- set_eliminated! (samples, vars)
182131end
183132
184- function udpate_cache_tree! (ne:: NestedEinsum , cache:: CacheTree{T} , el:: Pair{<:AbstractVector{L}} ) where {T, L}
133+ function udpate_cache_tree! (ne:: NestedEinsum , cache:: CacheTree{T} , el:: Pair{<:AbstractVector{L}} , size_dict :: Dict{L} ) where {T, L}
185134 OMEinsum. isleaf (ne) && return
135+ updated = false
186136 for (subcode, child, ix) in zip (ne. args, cache. children, getixsv (ne. eins))
137+ updated = updated || any (x-> x ∈ el. first, ix)
187138 child. content = eliminate_dimensions (child. content, ix, el)
188- udpate_cache_tree! (subcode, child, el)
139+ udpate_cache_tree! (subcode, child, el, size_dict )
189140 end
141+ updated && (cache. content = einsum (ne. eins, (getfield .(cache. children, :content )... ,), size_dict))
190142end
0 commit comments