Skip to content

Commit 1568a63

Browse files
committed
update
1 parent c105012 commit 1568a63

File tree

4 files changed

+15
-10
lines changed

4 files changed

+15
-10
lines changed

src/TensorInference.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ export sample
3636
export MMAPModel
3737

3838
# utils
39-
export matrix_product_state
39+
export random_matrix_product_state
4040

4141
include("Core.jl")
4242
include("RescaledArray.jl")

src/sampling.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ Returns a vector of vector, each element being a configurations defined on `get_
5050
### Arguments
5151
* `tn` is the tensor network model.
5252
* `n` is the number of samples to be returned.
53+
54+
### Keyword Arguments
55+
* `usecuda` is a boolean flag to indicate whether to use CUDA for tensor computation.
56+
* `queryvars` is the variables to be sampled, default is `get_vars(tn)`.
5357
"""
5458
function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get_vars(tn))::Samples
5559
# generate tropical tensors with its elements being log(p).
@@ -80,7 +84,7 @@ function sample(tn::TensorNetworkModel, n::Int; usecuda = false, queryvars = get
8084
end
8185
_Weights(x::AbstractVector{<:Real}) = Weights(x)
8286
function _Weights(x::AbstractArray{<:Complex})
83-
@assert all(e->abs(imag(e)) < 100*eps(abs(e)), x) "Complex probability encountered: $x"
87+
@assert all(e->abs(imag(e)) < max(100*eps(abs(e)), 1e-10), x) "Complex probability encountered: $x"
8488
return Weights(real.(x))
8589
end
8690

src/utils.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -331,19 +331,22 @@ connected in a chain.
331331
- `chi` is the bond dimension of the virtual indices.
332332
- `d` is the dimension of the physical indices.
333333
"""
334-
function matrix_product_state(n::Int, chi::Int, d::Int=2)
335-
tensors = Any[randn(ComplexF64, d, chi)]
334+
random_matrix_product_state(n::Int, chi::Int, d::Int=2) = random_matrix_product_state(ComplexF64, n, chi, d)
335+
function random_matrix_product_state(::Type{T}, n::Int, chi::Int, d::Int=2) where T
336+
# chi ^ (n-1) * (variance^n)^2 == 1/d^n
337+
variance = d^(-1/2) * chi^(-1/2+1/2n)
338+
tensors = Any[randn(T, d, chi) .* variance]
336339
physical_indices = collect(1:n)
337340
virtual_indices_ket = collect(n+1:2n-1)
338341
virtual_indices_bra = collect(2n:3n-2)
339342
ixs_ket = [[physical_indices[1], virtual_indices_ket[1]]]
340343
ixs_bra = [[physical_indices[1], virtual_indices_bra[1]]]
341344
for i = 2:n-1
342-
push!(tensors, randn(ComplexF64, chi, d, chi))
345+
push!(tensors, randn(T, chi, d, chi) .* variance)
343346
push!(ixs_ket, [virtual_indices_ket[i-1], physical_indices[i], virtual_indices_ket[i]])
344347
push!(ixs_bra, [virtual_indices_bra[i-1], physical_indices[i], virtual_indices_bra[i]])
345348
end
346-
push!(tensors, randn(ComplexF64, chi, d))
349+
push!(tensors, randn(T, chi, d) .* variance)
347350
push!(ixs_ket, [virtual_indices_ket[n-1], physical_indices[n]])
348351
push!(ixs_bra, [virtual_indices_bra[n-1], physical_indices[n]])
349352
tensors, ixs = [tensors..., conj.(tensors)...], [ixs_ket..., ixs_bra...]

test/sampling.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,9 @@ end
6666
@testset "sample MPS" begin
6767
n = 4
6868
chi = 3
69-
mps = matrix_product_state(n, chi)
69+
mps = random_matrix_product_state(n, chi)
7070
num_samples = 1000
71-
samples = map(1:num_samples) do i
72-
sample(mps, 1; queryvars=vcat(mps.mars...)).samples[:, 1]
73-
end
71+
sample(mps, num_samples; queryvars=vcat(mps.mars...))
7472
indices = map(samples) do sample
7573
sum(i->sample[i] * 2^(i-1), 1:4) + 1
7674
end

0 commit comments

Comments
 (0)