Adding OMEinsumContractionOrders.jl as a backend of TensorOperations.jl for finding the optimal contraction order #185
Changes to `.gitignore`:

```diff
@@ -1,4 +1,6 @@
 *.jl.cov
 *.jl.*.cov
 *.jl.mem
 Manifest.toml
+.vscode
+.DS_Store
```
New file (113 added lines): the package extension module.

```julia
module TensorOperationsOMEinsumContractionOrdersExt

using TensorOperations
using TensorOperations: TensorOperations as TO
using TensorOperations: TreeOptimizer
using OMEinsumContractionOrders
using OMEinsumContractionOrders: EinCode, NestedEinsum, SlicedEinsum, isleaf,
                                 optimize_kahypar_auto

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:GreedyMethod},
                        verbose::Bool) where {TDK,TDV}
    @debug "Using optimizer GreedyMethod from OMEinsumContractionOrders"
    ome_optimizer = GreedyMethod()
    return optimize(network, optdata, ome_optimizer, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:KaHyParBipartite},
                        verbose::Bool) where {TDK,TDV}
    @debug "Using optimizer KaHyParBipartite from OMEinsumContractionOrders"
    return optimize_kahypar(network, optdata, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:TreeSA},
                        verbose::Bool) where {TDK,TDV}
    @debug "Using optimizer TreeSA from OMEinsumContractionOrders"
    ome_optimizer = TreeSA()
    return optimize(network, optdata, ome_optimizer, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:SABipartite},
                        verbose::Bool) where {TDK,TDV}
    @debug "Using optimizer SABipartite from OMEinsumContractionOrders"
    ome_optimizer = SABipartite()
    return optimize(network, optdata, ome_optimizer, verbose)
end

function TO.optimaltree(network, optdata::Dict{TDK,TDV}, ::TreeOptimizer{:ExactTreewidth},
                        verbose::Bool) where {TDK,TDV}
    @debug "Using optimizer ExactTreewidth from OMEinsumContractionOrders"
    ome_optimizer = ExactTreewidth()
    return optimize(network, optdata, ome_optimizer, verbose)
end

function optimize(network, optdata::Dict{TDK,TDV}, ome_optimizer::CodeOptimizer,
                  verbose::Bool) where {TDK,TDV}
    @assert TDV <: Number "The values of the `optdata` dictionary must be `<:Number`"

    # transform the network into an EinCode
    code, size_dict = network2eincode(network, optdata)
    # optimize the contraction order using OMEinsumContractionOrders, which gives a NestedEinsum
    optcode = optimize_code(code, size_dict, ome_optimizer)

    # transform the optimized contraction order back into a contraction tree for the network
    optimaltree = eincode2contractiontree(optcode)

    # calculate the complexity of the contraction
    cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict)
    if verbose
        println("Optimal contraction tree: ", optimaltree)
        println(cc)
    end
    return optimaltree, 2.0^(cc.tc)
end

function optimize_kahypar(network, optdata::Dict{TDK,TDV}, verbose::Bool) where {TDK,TDV}
    @assert TDV <: Number "The values of the `optdata` dictionary must be `<:Number`"

    # transform the network into an EinCode
    code, size_dict = network2eincode(network, optdata)
    # optimize the contraction order using OMEinsumContractionOrders, which gives a NestedEinsum
    optcode = optimize_kahypar_auto(code, size_dict)

    # transform the optimized contraction order back into a contraction tree for the network
    optimaltree = eincode2contractiontree(optcode)

    # calculate the complexity of the contraction
    cc = OMEinsumContractionOrders.contraction_complexity(optcode, size_dict)
    if verbose
        println("Optimal contraction tree: ", optimaltree)
        println(cc)
    end
    return optimaltree, 2.0^(cc.tc)
end

function network2eincode(network, optdata)
    indices = unique(vcat(network...))
    new_indices = Dict([i => j for (j, i) in enumerate(indices)])
    new_network = [Int[new_indices[i] for i in t] for t in network]
    open_edges = Int[]
    # if an index appears only once, it is an open index
    for i in indices
        if sum([i in t for t in network]) == 1
            push!(open_edges, new_indices[i])
        end
    end
    size_dict = Dict([new_indices[i] => optdata[i] for i in keys(optdata)])
    return EinCode(new_network, open_edges), size_dict
end

function eincode2contractiontree(eincode::NestedEinsum)
    if isleaf(eincode)
        return eincode.tensorindex
    else
        return [eincode2contractiontree(arg) for arg in eincode.args]
    end
end

# TreeSA returns a SlicedEinsum with nslice = 0, so use its inner eins directly
function eincode2contractiontree(eincode::SlicedEinsum)
    return eincode2contractiontree(eincode.eins)
end

end
```
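As a hedged usage sketch (not part of the diff): once TensorOperations with this PR and OMEinsumContractionOrders are both loaded, the extension methods above should be reachable through `TO.optimaltree`. The network, dimensions, and verbose flag below are illustrative.

```julia
using TensorOperations, OMEinsumContractionOrders
using TensorOperations: TreeOptimizer

# a small network: each inner vector lists the index labels of one tensor
network = [[1, 2, 3], [2, 4], [3, 4, 5]]
# dimension of every index label (the values must be <:Number, per the asserts above)
optdata = Dict(1 => 2, 2 => 2, 3 => 4, 4 => 4, 5 => 2)

# returns the contraction tree and 2^tc, the estimated contraction cost
tree, cost = TensorOperations.optimaltree(network, optdata,
                                          TreeOptimizer{:GreedyMethod}(), true)
```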
Changes to the `ncon` docstring:

```diff
@@ -1,5 +1,6 @@
 """
     ncon(tensorlist, indexlist, [conjlist, sym]; order = ..., output = ..., backend = ..., allocator = ...)
+    ncon(tensorlist, indexlist, optimizer, conjlist; output=..., kwargs...)
```

A reviewer suggested swapping the last two positional arguments:

```diff
-    ncon(tensorlist, indexlist, optimizer, conjlist; output=..., kwargs...)
+    ncon(tensorlist, indexlist, conjlist, optimizer; output=..., kwargs...)
```
This may lead to wrong dispatch, since `conjlist` has a default value. To avoid conflicts, I combined these two cases into

    ncon(tensorlist, indexlist, [conjlist, sym]; order = ..., output = ..., optimizer = ..., backend = ..., allocator = ...)

where the optimizer is now a keyword argument, and we do not allow `order` and `optimizer` to be specified together.
Here, you probably want to use `tensorstructure(tensors[i], j, conjlist[i])`, as this is part of our interface. In particular, `size` might not always be defined for custom tensor types.
(https://github.com/Jutho/TensorOperations.jl/blob/2da3b048e656efa9ad451dc843b34a9de5465571/src/interface.jl#L188)
Thank you very much for pointing that out; I did not notice that the conjlist may change the size of the tensor.
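A hedged sketch of the suggested fix (the names `tensors`, `network`, and `conjlist` stand in for the `ncon` arguments; `tensorstructure` is the interface function the reviewer links to):

```julia
# collect label => dimension pairs through the TensorOperations interface,
# so custom tensor types without `size`, and conjugation-dependent structure,
# are handled correctly
optdata = Dict{Any,Any}()
for (i, labels) in enumerate(network)
    for (j, label) in enumerate(labels)
        optdata[label] = tensorstructure(tensors[i], j, conjlist[i])
    end
end
```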
Here, it seems a bit strange to me to keep the optimizer as a Symbol. Would it not make more sense to immediately pass the optimizer itself?
Previously I wanted to use that to avoid exporting the struct `TreeOptimizer`. Now I export it together with the optimizers:

    export TreeOptimizer, ExhaustiveSearchOptimizer, GreedyMethodOptimizer, KaHyParBipartiteOptimizer, TreeSAOptimizer, SABipartiteOptimizer, ExactTreewidthOptimizer

so that we can now directly pass the optimizer into the function. Do you think it's a good idea?
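As a small illustration of the two spellings under discussion: the explicit `TreeOptimizer{...}()` form appears elsewhere in this PR, while the definitions of the exported convenience names are not visible in the diff, so treating them as zero-argument constructors is an assumption here.

```julia
opt1 = TreeOptimizer{:TreeSA}()  # construct from the algorithm name (form used in this PR)
opt2 = TreeSAOptimizer()         # exported convenience name (assumed to construct the same)
```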
Changes to the keyword parser `tensorparser`:

```diff
@@ -71,6 +71,8 @@ function tensorparser(tensorexpr, kwargs...)
         end
     end
     # now handle the remaining keyword arguments
+    optimizer = TreeOptimizer{:ExhaustiveSearch}() # the default optimizer implemented in TensorOperations.jl
+    optval = nothing
     for (name, val) in kwargs
         if name == :order
             isexpr(val, :tuple) ||
@@ -86,18 +88,29 @@ function tensorparser(tensorexpr, kwargs...)
                 throw(ArgumentError("Invalid use of `costcheck`, should be `costcheck=warn` or `costcheck=cache`"))
             parser.contractioncostcheck = val
         elseif name == :opt
-            if val isa Bool && val
-                optdict = optdata(tensorexpr)
-            elseif val isa Expr
-                optdict = optdata(val, tensorexpr)
+            optval = val
+        elseif name == :opt_algorithm
```
Member: I think here you will have to be a little careful; in principle there is no order to the keyword arguments. My best guess is that you probably want to attempt to extract an optimizer and optdict, and only after all kwargs have been parsed can you construct the contraction tree builder.

Author: Thank you very much for pointing that out, I did not notice that previously.
```diff
+            if val isa Symbol
+                optimizer = TreeOptimizer{val}()
             else
-                throw(ArgumentError("Invalid use of `opt`, should be `opt=true` or `opt=OptExpr`"))
+                throw(ArgumentError("Invalid use of `opt_algorithm`, should be `opt_algorithm=ExhaustiveSearch` or `opt_algorithm=NameOfAlgorithm`"))
             end
-            parser.contractiontreebuilder = network -> optimaltree(network, optdict)[1]
         elseif !(name == :backend || name == :allocator) # these two have been handled
             throw(ArgumentError("Unknown keyword argument `name`."))
         end
     end
+    # construct the contraction tree builder after all keyword arguments have been processed
+    if !isnothing(optval)
+        if optval isa Bool && optval
+            optdict = optdata(tensorexpr)
+        elseif optval isa Expr
+            optdict = optdata(optval, tensorexpr)
+        else
+            throw(ArgumentError("Invalid use of `opt`, should be `opt=true` or `opt=OptExpr`"))
+        end
+        parser.contractiontreebuilder = network -> optimaltree(network, optdict;
+                                                               optimizer=optimizer)[1]
+    end
     return parser
 end
```
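Taken together with the extension above, the new keyword selects the contraction-order algorithm at the macro level. A hedged usage sketch (illustrative tensors; `opt_algorithm` takes the algorithm name, which the parser turns into `TreeOptimizer{val}()`, and non-default algorithms require OMEinsumContractionOrders to be loaded so the extension methods exist):

```julia
using TensorOperations, OMEinsumContractionOrders

A = randn(5, 5, 5)
B = randn(5, 5, 5)

# `opt=true` turns on contraction-order optimization; `opt_algorithm` picks
# the backend algorithm, here dispatching to TreeOptimizer{:GreedyMethod}()
@tensor opt=true opt_algorithm=GreedyMethod D[a, d] := A[a, b, c] * B[b, c, d]
```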