Skip to content

Commit 24987de

Browse files
committed
refactor interfaces
1 parent 7c4c28e commit 24987de

File tree

9 files changed

+217
-190
lines changed

9 files changed

+217
-190
lines changed

src/arithematics.jl

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
export is_commutative_semiring
2-
export Max2Poly, Polynomial, Tropical, CountingTropical, ConfigTropical, StaticBitVector, Mod, ConfigEnumerator
2+
export Max2Poly, Polynomial, Tropical, CountingTropical, StaticBitVector, Mod, ConfigEnumerator, onehotv
33

44
using Polynomials: Polynomial
5-
using TropicalNumbers: Tropical, CountingTropical, ConfigTropical, StaticBitVector
5+
using TropicalNumbers: Tropical, CountingTropical, StaticBitVector
66
using Mods, Primes
77

8+
# patch for Tropical numbers
9+
Base.isnan(x::Tropical) = isnan(x.n)
10+
811
# pirate
912
Base.abs(x::Mod) = x
1013
Base.isless(x::Mod{N}, y::Mod{N}) where N = mod(x.val, N) < mod(y.val, N)
@@ -120,10 +123,61 @@ Base.one(::ConfigEnumerator{N,C}) where {N,C} = one(ConfigEnumerator{N,C})
120123
Base.show(io::IO, x::ConfigEnumerator) = print(io, "{", join(x.data, ", "), "}")
121124
Base.show(io::IO, ::MIME"text/plain", x::ConfigEnumerator) = Base.show(io, x)
122125

123-
# patch
126+
# the algebra sampling one of the configurations
127+
struct ConfigSampler{N,C}
128+
data::StaticBitVector{N,C}
129+
end
130+
131+
Base.:(==)(x::ConfigSampler{N,C}, y::ConfigSampler{N,C}) where {N,C} = x.data == y.data
132+
133+
function Base.:+(x::ConfigSampler{N,C}, y::ConfigSampler{N,C}) where {N,C} # biased sampling: return `x`, maybe using random sampler is better.
134+
return x
135+
end
136+
137+
function Base.:*(x::ConfigSampler{L,C}, y::ConfigSampler{L,C}) where {L,C}
138+
ConfigSampler(x.data | y.data)
139+
end
140+
141+
Base.zero(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(TropicalNumbers.statictrues(StaticBitVector{N,C}))
142+
Base.one(::Type{ConfigSampler{N,C}}) where {N,C} = ConfigSampler{N,C}(TropicalNumbers.staticfalses(StaticBitVector{N,C}))
143+
Base.zero(::ConfigSampler{N,C}) where {N,C} = zero(ConfigSampler{N,C})
144+
Base.one(::ConfigSampler{N,C}) where {N,C} = one(ConfigSampler{N,C})
124145

146+
# A patch to make `Polynomial{ConfigEnumerator}` work
125147
function Base.:*(a::Int, y::ConfigEnumerator)
126148
a == 0 && return zero(y)
127149
a == 1 && return y
128150
error("multiplication between int and config enumerator is not defined.")
129-
end
151+
end
152+
153+
# convert from counting type to bitstring type
154+
for (F,TP) in [(:bitstringset_type, :ConfigEnumerator), (:bitstringsampler_type, :ConfigSampler)]
155+
@eval begin
156+
function $F(::Type{T}, n::Int) where {T<:Max2Poly}
157+
Max2Poly{$F(n)}
158+
end
159+
function $F(::Type{T}, n::Int) where {TX, T<:Polynomial{C,TX} where C}
160+
Polynomial{$F(n),:x}
161+
end
162+
function $F(::Type{T}, n::Int) where {TV, T<:CountingTropical{TV}}
163+
CountingTropical{TV, $F(n)}
164+
end
165+
function $F(n::Integer)
166+
C = TropicalNumbers._nints(n)
167+
return $TP{n, C}
168+
end
169+
end
170+
end
171+
172+
# utilities for creating onehot vectors
173+
function onehotv(::Type{Polynomial{BS,X}}, x) where {BS,X}
174+
Polynomial{BS,X}([zero(BS), onehotv(BS, x)])
175+
end
176+
function onehotv(::Type{Max2Poly{BS}}, x) where {BS}
177+
Max2Poly{BS}(zero(BS), onehotv(BS, x),1)
178+
end
179+
function onehotv(::Type{CountingTropical{TV,BS}}, x) where {TV,BS}
180+
CountingTropical{TV,BS}(one(TV), onehotv(BS, x))
181+
end
182+
onehotv(::Type{ConfigEnumerator{N,C}}, i::Integer) where {N,C} = ConfigEnumerator([TropicalNumbers.onehot(StaticBitVector{N,C}, i)])
183+
onehotv(::Type{ConfigSampler{N,C}}, i::Integer) where {N,C} = ConfigSampler(TropicalNumbers.onehot(StaticBitVector{N,C}, i))

src/bounding.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@ using TupleTools
22

33
export bounding_contract
44

5-
Base.isnan(x::Tropical) = isnan(x.n)
5+
"""
6+
backward_tropical(mode, ixs, xs, iy, y, ymask, size_dict)
7+
8+
The backward rule for tropical einsum.
9+
* `mode` can be one of `:all` and `:single`,
10+
* `ixs` and `xs` are labels and tensor data for input tensors,
11+
* `iy` and `y` are labels and tensor data for the output tensor,
12+
* `ymask` is the boolean mask for gradients,
13+
* `size_dict` is a key-value map from tensor label to dimension size.
14+
"""
615
function backward_tropical(mode, @nospecialize(ixs), @nospecialize(xs), @nospecialize(iy), @nospecialize(y), @nospecialize(ymask), size_dict)
716
y .= inv.(y) .* ymask
817
masks = []
@@ -25,6 +34,7 @@ function backward_tropical(mode, @nospecialize(ixs), @nospecialize(xs), @nospeci
2534
return masks
2635
end
2736

37+
# one of the entry in `A` that equal to the corresponding entry in `X` is masked to true.
2838
function onehotmask(A::AbstractArray{T}, X::AbstractArray{T}) where T
2939
@assert length(A) == length(X)
3040
mask = falses(size(A)...)
@@ -40,6 +50,7 @@ function onehotmask(A::AbstractArray{T}, X::AbstractArray{T}) where T
4050
return mask
4151
end
4252

53+
# the data structure storing intermediate `NestedEinsum` contraction results.
4354
struct CacheTree{T}
4455
content::AbstractArray{T}
4556
siblings::Vector{CacheTree{T}}
@@ -50,10 +61,11 @@ function cached_einsum(code::Int, @nospecialize(xs), size_dict)
5061
end
5162
function cached_einsum(code::NestedEinsum, @nospecialize(xs), size_dict)
5263
caches = [cached_einsum(arg, xs, size_dict) for arg in code.args]
53-
y = einsum(code.eins, (getfield.(caches, :content)...,), size_dict)
64+
y = dynamic_einsum(code.eins, (getfield.(caches, :content)...,); size_info=size_dict)
5465
CacheTree(y, caches)
5566
end
5667

68+
# computed mask tree by back propagation
5769
function generate_masktree(code::Int, cache, mask, size_dict, mode=:all)
5870
CacheTree(mask, CacheTree{Bool}[])
5971
end
@@ -62,6 +74,7 @@ function generate_masktree(code::NestedEinsum, cache, mask, size_dict, mode=:all
6274
return CacheTree(mask, generate_masktree.(code.args, cache.siblings, submasks, Ref(size_dict), mode))
6375
end
6476

77+
# The masked einsum contraction
6578
function masked_einsum(code::Int, @nospecialize(xs), masks, size_dict)
6679
y = copy(xs[code])
6780
y[OMEinsum.asarray(.!masks.content)] .= Ref(zero(eltype(y))); y

src/configurations.jl

Lines changed: 25 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,97 +1,39 @@
1-
export mis_config, ConfigEnumerator, ConfigTropical
1+
export getconfigs_bounded, getconfigs_direct
22

3-
function symbols(::EinCode{ixs}) where ixs
4-
res = []
5-
for ix in ixs
6-
for l in ix
7-
if l res
8-
push!(res, l)
9-
end
10-
end
11-
end
12-
return res
13-
end
14-
15-
function mis_config(code; all=false, bounding=true, usecuda=false)
3+
function getconfigs_bounded(gp::GraphProblem; all=false, usecuda=false)
164
if all && usecuda
175
throw(ArgumentError("ConfigEnumerator can not be computed on GPU!"))
186
end
19-
flatten_code = flatten(code)
20-
syms = unique(Iterators.flatten(filter(x->length(x)==1,OMEinsum.getixs(flatten_code))))
7+
T = (all ? bitstringset_type : bitstringsampler_type)(CountingTropical{Int64}, length(labels(gp.code)))
8+
syms = labels(gp.code)
219
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
22-
N = length(vertex_index)
23-
C = TropicalNumbers._nints(N)
24-
xs = map(getixs(flatten_code)) do ix
25-
T = all ? CountingTropical{Float64, ConfigEnumerator{N,C}} : ConfigTropical{Float64, N, C}
26-
if length(ix) == 2
27-
return misb(T)
28-
else
29-
s = TropicalNumbers.onehot(StaticBitVector{N,C}, vertex_index[ix[1]])
30-
if all
31-
misv(T, T(1.0, ConfigEnumerator([s])))
32-
else
33-
misv(T, T(1.0, s))
34-
end
35-
end
10+
xst = generate_tensors(l->TropicalF64(1.0), gp)
11+
ymask = trues(fill(2, length(OMEinsum.getiy(flatten(gp.code))))...)
12+
if usecuda
13+
xst = CuArray.(xst)
14+
ymask = CuArray(ymask)
3615
end
37-
if bounding
38-
ymask = trues(fill(2, length(getiy(flatten_code)))...)
39-
xst = map(getixs(flatten_code)) do ix
40-
length(ix) == 1 ? misv(TropicalF64,Tropical(1.0)) : misb(TropicalF64)
41-
end
42-
if usecuda
43-
ymask = CuArray(ymask)
44-
xst = CuArray.(xst)
45-
end
46-
if all
47-
return bounding_contract(code, xst, ymask, xs)
48-
else
49-
@assert ndims(ymask) == 0
50-
t, res = mis_config_ad(code, xst, ymask)
51-
return fill(ConfigTropical(asscalar(t).n, StaticBitVector(map(l->res[l], 1:N))))
52-
end
16+
if all
17+
xs = generate_tensors(l->onehotv(T, vertex_index[l]), gp)
18+
return bounding_contract(gp.code, xst, ymask, xs)
5319
else
54-
if usecuda
55-
xs = CuArray.(xs)
56-
end
57-
return dynamic_einsum(code, xs)
20+
@assert ndims(ymask) == 0
21+
t, res = mis_config_ad(gp.code, xst, ymask)
22+
N = length(vertex_index)
23+
return fill(CountingTropical(asscalar(t).n, ConfigSampler(StaticBitVector(map(l->res[l], 1:N)))))
5824
end
5925
end
6026

61-
export mis_max2_config
62-
function mis_max2_config(code)
63-
flatten_code = flatten(code)
64-
syms = unique(Iterators.flatten(filter(x->length(x)==1,OMEinsum.getixs(flatten_code))))
65-
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
66-
N = length(vertex_index)
67-
C = TropicalNumbers._nints(N)
68-
xs = map(getixs(flatten_code)) do ix
69-
T = Max2Poly{ConfigEnumerator{N,C}}
70-
if length(ix) == 2
71-
return misb(T)
72-
else
73-
s = TropicalNumbers.onehot(StaticBitVector{N,C}, vertex_index[ix[1]])
74-
misv(T, Max2Poly(zero(ConfigEnumerator{N,C}), ConfigEnumerator([s]), 1.0))
75-
end
27+
function getconfigs_direct(gp::GraphProblem; all=false, usecuda=false)
28+
if all && usecuda
29+
throw(ArgumentError("ConfigEnumerator can not be computed on GPU!"))
7630
end
77-
return dynamic_einsum(code, xs)
78-
end
79-
80-
export all_config
81-
function all_config(code)
82-
flatten_code = flatten(code)
83-
syms = unique(Iterators.flatten(filter(x->length(x)==1,OMEinsum.getixs(flatten_code))))
31+
T = (all ? bitstringset_type : bitstringsampler_type)(CountingTropical{Int64}, length(labels(gp.code)))
32+
syms = labels(gp.code)
8433
vertex_index = Dict([s=>i for (i, s) in enumerate(syms)])
85-
N = length(vertex_index)
86-
C = TropicalNumbers._nints(N)
87-
xs = map(getixs(flatten_code)) do ix
88-
T = Polynomial{ConfigEnumerator{N,C}, :x}
89-
if length(ix) == 2
90-
return misb(T)
91-
else
92-
s = TropicalNumbers.onehot(StaticBitVector{N,C}, vertex_index[ix[1]])
93-
misv(T, Polynomial([zero(ConfigEnumerator{N,C}), ConfigEnumerator([s])]))
94-
end
34+
xs = generate_tensors(l->onehotv(T, vertex_index[l]), gp)
35+
if usecuda
36+
xs = CuArray.(xs)
9537
end
96-
return dynamic_einsum(code, xs)
38+
dynamic_einsum(gp.code, xs)
9739
end

0 commit comments

Comments
 (0)