@@ -489,16 +489,33 @@ size_all_negative(::SpinGlass) = false
# Spin glass energies can take either sign, so sizes are not all positive.
size_all_positive(::SpinGlass) = false
490490
491491# NOTE: `findmin` and `findmax` are required by `ProblemReductions.jl`
"""
    GTNSolver(; optimizer=TreeSA(), single=false, usecuda=false, T=Float64)

A generic tensor network based backend for the `findbest`, `findmin` and
`findmax` interfaces in `ProblemReductions.jl`.

# Keywords
- `optimizer`: the optimizer for the tensor network contraction.
- `single`: a switch to return a single solution instead of all solutions.
- `usecuda`: a switch to use CUDA (when applicable); the user needs to run
  `using CUDA` before turning on this switch.
- `T`: the "base" element type, which can sometimes be used to reduce the
  memory cost.
"""
Base.@kwdef struct GTNSolver
    optimizer::OMEinsum.CodeOptimizer = TreeSA()
    single::Bool = false
    usecuda::Bool = false
    T::Type = Float64
end
# Generate `Base.findmin` and `Base.findmax` methods for `GTNSolver` from the
# corresponding solution-space properties. `PROP` collects all optimal
# configurations, `SPROP` collects a single optimal configuration.
for (PROP, SPROP, SOLVER) in [
        (:ConfigsMin, :SingleConfigMin, :findmin), (:ConfigsMax, :SingleConfigMax, :findmax)
    ]
    @eval function Base.$(SOLVER)(problem::AbstractProblem, solver::GTNSolver)
        if solver.single
            # One optimal configuration, wrapped in a vector for a uniform return shape.
            res = [solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(SPROP)(); usecuda=solver.usecuda, T=solver.T)[].c.data]
        else
            # All optimal configurations, stored as a tree (`tree_storage=true`).
            res = collect(solve(GenericTensorNetwork(problem; optimizer=solver.optimizer), $(PROP)(; tree_storage=true); usecuda=solver.usecuda, T=solver.T)[].c)
        end
        # Map 0-based bit strings to problem-specific configurations (1-based ids).
        return map(x -> ProblemReductions.id_to_config(problem, Int.(x) .+ 1), res)
    end
end
0 commit comments