|
| 1 | +""" |
| 2 | + save_configs(filename, data::ConfigEnumerator; format=:binary) |
| 3 | +
|
| 4 | +Save configurations `data` to file `filename`. The format is `:binary` or `:text`. |
| 5 | +""" |
| 6 | +function save_configs(filename, data::ConfigEnumerator{N,S,C}; format::Symbol=:binary) where {N,S,C} |
| 7 | + if format == :binary |
| 8 | + write(filename, raw_matrix(data)) |
| 9 | + elseif format == :text |
| 10 | + writedlm(filename, plain_matrix(data)) |
| 11 | + else |
| 12 | + error("format must be `:binary` or `:text`, got `:$format`") |
| 13 | + end |
| 14 | +end |
| 15 | + |
| 16 | +""" |
| 17 | + load_configs(filename; format=:binary, bitlength=nothing, nflavors=2) |
| 18 | +
|
| 19 | +Load configurations from file `filename`. The format is `:binary` or `:text`. |
| 20 | +If the format is `:binary`, the bitstring length `bitlength` must be specified, |
| 21 | +`nflavors` specifies the degree of freedom. |
| 22 | +""" |
| 23 | +function load_configs(filename; bitlength=nothing, format::Symbol=:binary, nflavors=2) |
| 24 | + if format == :binary |
| 25 | + bitlength === nothing && error("you need to specify `bitlength` for reading configurations from binary files.") |
| 26 | + S = ceil(Int, log2(nflavors)) |
| 27 | + C = _nints(bitlength, S) |
| 28 | + return _from_raw_matrix(StaticElementVector{bitlength,S,C}, reshape(reinterpret(UInt64, read(filename)),C,:)) |
| 29 | + elseif format == :text |
| 30 | + return from_plain_matrix(readdlm(filename); nflavors=nflavors) |
| 31 | + else |
| 32 | + error("format must be `:binary` or `:text`, got `:$format`") |
| 33 | + end |
| 34 | +end |
| 35 | + |
| 36 | +function raw_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C} |
| 37 | + m = zeros(UInt64, C, length(x)) |
| 38 | + @inbounds for i=1:length(x), j=1:C |
| 39 | + m[j,i] = x.data[i].data[j] |
| 40 | + end |
| 41 | + return m |
| 42 | +end |
| 43 | +function plain_matrix(x::ConfigEnumerator{N,S,C}) where {N,S,C} |
| 44 | + m = zeros(UInt8, N, length(x)) |
| 45 | + @inbounds for i=1:length(x), j=1:N |
| 46 | + m[j,i] = x.data[i][j] |
| 47 | + end |
| 48 | + return m |
| 49 | +end |
| 50 | + |
| 51 | +function from_raw_matrix(m; bitlength, nflavors=2) |
| 52 | + S = ceil(Int,log2(nflavors)) |
| 53 | + C = size(m, 1) |
| 54 | + T = StaticElementVector{bitlength,S,C} |
| 55 | + @assert bitlength*S <= C*64 |
| 56 | + _from_raw_matrix(T, m) |
| 57 | +end |
| 58 | +function _from_raw_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C} |
| 59 | + data = zeros(StaticElementVector{N,S,C}, size(m, 2)) |
| 60 | + @inbounds for i=1:size(m, 2) |
| 61 | + data[i] = StaticElementVector{N,S,C}(NTuple{C,UInt64}(view(m,:,i))) |
| 62 | + end |
| 63 | + return ConfigEnumerator(data) |
| 64 | +end |
| 65 | +function from_plain_matrix(m::Matrix; nflavors=2) |
| 66 | + S = ceil(Int,log2(nflavors)) |
| 67 | + N = size(m, 1) |
| 68 | + C = _nints(N, S) |
| 69 | + T = StaticElementVector{N,S,C} |
| 70 | + _from_plain_matrix(T, m) |
| 71 | +end |
| 72 | +function _from_plain_matrix(::Type{StaticElementVector{N,S,C}}, m::AbstractMatrix) where {N,S,C} |
| 73 | + data = zeros(StaticElementVector{N,S,C}, size(m, 2)) |
| 74 | + @inbounds for i=1:size(m, 2) |
| 75 | + data[i] = convert(StaticElementVector{N,S,C}, view(m, :, i)) |
| 76 | + end |
| 77 | + return ConfigEnumerator(data) |
| 78 | +end |
| 79 | + |
| 80 | +# convert to Matrix |
| 81 | +Base.Matrix(ce::ConfigEnumerator) = plain_matrix(ce) |
| 82 | +Base.Vector(ce::StaticElementVector) = collect(ce) |
| 83 | + |
| 84 | +########## saving tree #################### |
| 85 | +""" |
| 86 | + save_sumproduct(filename, t::SumProductTree) |
| 87 | +
|
| 88 | +Serialize a sum-product tree into a file. |
| 89 | +""" |
| 90 | +save_sumproduct(filename::String, t::SumProductTree) = serialize(filename, dict_serialize_tree!(t, Dict{UInt,Any}())) |
| 91 | + |
| 92 | +""" |
| 93 | + load_sumproduct(filename) |
| 94 | +
|
| 95 | +Deserialize a sum-product tree from a file. |
| 96 | +""" |
| 97 | +load_sumproduct(filename::String) = dict_deserialize_tree(deserialize(filename)...) |
| 98 | + |
| 99 | +function dict_serialize_tree!(t::SumProductTree, d::Dict) |
| 100 | + id = objectid(t) |
| 101 | + if !haskey(d, id) |
| 102 | + if t.tag === GraphTensorNetworks.LEAF || t.tag === GraphTensorNetworks.ZERO || t.tag == GraphTensorNetworks.ONE |
| 103 | + d[id] = t |
| 104 | + else |
| 105 | + d[id] = (t.tag, objectid(t.left), objectid(t.right)) |
| 106 | + dict_serialize_tree!(t.left, d) |
| 107 | + dict_serialize_tree!(t.right, d) |
| 108 | + end |
| 109 | + end |
| 110 | + return id, d |
| 111 | +end |
| 112 | + |
| 113 | +function dict_deserialize_tree(id::UInt, d::Dict) |
| 114 | + @assert haskey(d, id) |
| 115 | + content = d[id] |
| 116 | + if content isa SumProductTree |
| 117 | + return content |
| 118 | + else |
| 119 | + (tag, left, right) = content |
| 120 | + t = SumProductTree(tag, dict_deserialize_tree(left, d), dict_deserialize_tree(right, d)) |
| 121 | + d[id] = t |
| 122 | + return t |
| 123 | + end |
| 124 | +end |
| 125 | + |
0 commit comments