# Fix the labels in `el.first` to the values in `el.second` and slice the
# remaining dimensions of `x` fully; returns the reduced array.
# NOTE(review): signature tail reconstructed from the RescaledArray method
# below — confirm against the original file.
function eliminate_dimensions(x::AbstractArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractVector}) where {T, N, L}
    @assert length(ix) == N
    selector = eliminated_selector(size(x), ix, el.first, el.second)
    return x[selector...]
end
# RescaledArray method: eliminate dimensions on the normalized payload and
# carry the `log_factor` field over unchanged.
function eliminate_dimensions(x::RescaledArray{T, N}, ix::AbstractVector{L}, el::Pair{<:AbstractVector{L}, <:AbstractVector}) where {T, N, L}
    reduced = eliminate_dimensions(x.normalized_value, ix, el)
    return RescaledArray(x.log_factor, reduced)
end
38+
3539function eliminated_size (size0, ix, labels)
3640 @assert length (size0) == length (ix)
3741 return ntuple (length (ix)) do i
@@ -53,7 +57,7 @@ function eliminate_dimensions_addbatch!(x::AbstractArray{T, N}, ix::AbstractVect
5357 @assert length (ix) == N
5458 res = similar (x, (eliminated_size (size (x), ix, el. first)... , nbatch))
5559 for ibatch in 1 : nbatch
56- selectdim (res, N+ 1 , ibatch) . = eliminate_dimensions (x, ix, el. first=> view (el. second, :, ibatch))
60+ copyto! ( selectdim (res, N+ 1 , ibatch), eliminate_dimensions (x, ix, el. first=> view (el. second, :, ibatch) ))
5761 end
5862 push! (ix, batch_label)
5963 return res
@@ -63,7 +67,7 @@ function eliminate_dimensions_withbatch(x::AbstractArray{T, N}, ix::AbstractVect
6367 @assert length (ix) == N && size (x, N) == nbatch
6468 res = similar (x, (eliminated_size (size (x), ix, el. first)))
6569 for ibatch in 1 : nbatch
66- selectdim (res, N, ibatch) . = eliminate_dimensions (selectdim (x, N, ibatch), ix[1 : end - 1 ], el. first=> view (el. second, :, ibatch))
70+ copyto! ( selectdim (res, N, ibatch), eliminate_dimensions (selectdim (x, N, ibatch), ix[1 : end - 1 ], el. first=> view (el. second, :, ibatch) ))
6771 end
6872 return res
6973end
@@ -79,28 +83,28 @@ Returns a vector of vectors, each element being a configuration defined on `get_
7983* `n` is the number of samples to be returned.
8084
8185### Keyword Arguments
86+ * `rescale` is a boolean flag to indicate whether to rescale the tensors during contraction.
8287* `usecuda` is a boolean flag to indicate whether to use CUDA for tensor computation.
8388* `queryvars` is the variables to be sampled, default is `get_vars(tn)`.
8489"""
85- function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false , queryvars = get_vars (tn)):: Samples
90+ function sample (tn:: TensorNetworkModel , n:: Int ; usecuda = false , queryvars = get_vars (tn), rescale :: Bool = false ):: Samples
8691 # generate tropical tensors with its elements being log(p).
87- xs = adapt_tensors (tn; usecuda, rescale = false )
92+ xs = adapt_tensors (tn; usecuda, rescale)
8893 # infer size from the contraction code and the input tensors `xs`, returns a label-size dictionary.
8994 size_dict = OMEinsum. get_size_dict! (getixsv (tn. code), xs, Dict {Int, Int} ())
9095 # forward compute and cache intermediate results.
9196 cache = cached_einsum (tn. code, xs, size_dict)
9297 # initialize `y̅` as the initial batch of samples.
9398 iy = getiyv (tn. code)
9499 idx = map (l-> findfirst (== (l), queryvars), iy ∩ queryvars)
95- indices = StatsBase. sample (CartesianIndices (size (cache. content)), _Weights ( vec ( cache. content) ), n)
100+ indices = StatsBase. sample (CartesianIndices (size (cache. content)), _weight ( cache. content), n)
96101 configs = zeros (Int, length (queryvars), n)
97102 for i= 1 : n
98103 configs[idx, i] .= indices[i]. I .- 1
99104 end
100105 samples = Samples (configs, queryvars)
101106 # back-propagate
102- env = similar (cache. content, (size (cache. content)... , n)) # batched env
103- fill! (env, one (eltype (env)))
107+ env = ones_like (cache. content, n)
104108 batch_label = _newindex (OMEinsum. uniquelabels (tn. code))
105109 code = deepcopy (tn. code)
106110 iy_env = [OMEinsum. getiyv (code)... , batch_label]
@@ -115,10 +119,22 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
115119end
# Produce a fresh tensor label guaranteed not to collide with the existing ones.
# Integer/character labels: one past the maximum existing label.
function _newindex(labels::AbstractVector{<:Union{Int, Char}})
    return maximum(labels) + 1
end
# Symbolic labels: a gensym is unique by construction.
_newindex(::AbstractVector{Symbol}) = gensym(:batch)
118- _Weights (x:: AbstractVector{<:Real} ) = Weights (x)
119- function _Weights (x:: AbstractArray{<:Complex} )
122+ _weight (x:: AbstractArray{<:Real} ) = Weights (_normvec (x))
"""
    _weight(x::AbstractArray{<:Complex})

Convert an array of complex-valued probabilities into `StatsBase.Weights`
over its flattened elements. The imaginary parts must be negligible
(rounding noise only); otherwise an `ArgumentError` is thrown.
"""
function _weight(_x::AbstractArray{<:Complex})
    x = _normvec(_x)
    # Validate with an explicit throw rather than `@assert`: assertions may be
    # compiled out, and a genuinely complex "probability" is a caller error
    # that must never be silently sampled from.
    all(e -> abs(imag(e)) < max(100 * eps(abs(e)), 1e-8), x) ||
        throw(ArgumentError("Complex probability encountered: $x"))
    return _weight(real.(x))
end
128+ _normvec (x:: AbstractArray ) = vec (x)
129+ _normvec (x:: RescaledArray ) = vec (x. normalized_value)
130+
"""
    ones_like(x::AbstractArray, n::Int)

Return an array of ones shaped like `x` with one extra trailing (batch)
dimension of size `n`, matching the element type and storage kind of `x`
(allocated via `similar`, so e.g. GPU arrays stay on the GPU).
"""
function ones_like(x::AbstractArray{T}, n::Int) where {T}
    out = similar(x, (size(x)..., n))
    fill!(out, one(T))
    return out
end
# RescaledArray method: a batch of ones needs no rescaling, so the log factor
# is zero (with the same numeric type as `x.log_factor`).
function ones_like(x::RescaledArray, n::Int)
    payload = ones_like(x.normalized_value, n)
    return RescaledArray(zero(x.log_factor), payload)
end
123139
124140function generate_samples! (se:: SlicedEinsum , cache:: CacheTree{T} , iy_env:: Vector{Int} , env:: AbstractArray{T} , samples:: Samples{L} , pool, batch_label:: L , size_dict:: Dict{L} ) where {T, L}
# Sample one configuration of `vars` from `probabilities` and write it
# (0-based) into `sample` at the positions of `vars` within `labels`.
# NOTE(review): signature tail reconstructed (`probabilities` type and where
# clause were cut off in the diff header) — confirm against the original file.
function update_samples!(labels, sample, vars::AbstractVector{L}, probabilities::AbstractArray{T, N}) where {L, T, N}
    @assert length(vars) == N
    totalset = CartesianIndices(probabilities)
    eliminated_locs = idx4labels(labels, vars)
    # Draw one joint configuration weighted by the (flattened) probabilities.
    config = StatsBase.sample(totalset, _weight(probabilities))
    sample[eliminated_locs] .= config.I .- 1
end
183199
0 commit comments