From bb387c9bcc59ac0fdf9ef9ab1d3a4c05889b1a8a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 10:18:29 -0400 Subject: [PATCH 01/13] CompatHelper: bump compat for NamedGraphs to 0.7, (keep existing compat) (#4) Co-authored-by: CompatHelper Julia Co-authored-by: Matt Fishman --- Project.toml | 4 ++-- test/Project.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index ce4f361..21bc2ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.2" +version = "0.1.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -25,7 +25,7 @@ Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.7.13" -NamedGraphs = "0.6.9" +NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" julia = "1.10" diff --git a/test/Project.toml b/test/Project.toml index 94f32e3..80debae 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,7 +17,7 @@ Graphs = "1.13.1" ITensorBase = "0.2.12" ITensorNetworksNext = "0.1.1" NamedDimsArrays = "0.7.14" -NamedGraphs = "0.6.8" +NamedGraphs = "0.6.8, 0.7" SafeTestsets = "0.1" Suppressor = "0.2.8" Test = "1.10" From 3549a9cd3d06498c9dd0579cf98e720b7cabf0f8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 3 Oct 2025 16:09:15 -0400 Subject: [PATCH 02/13] Automatic Runic.jl run (#6) Co-authored-by: mtfishman <7855256+mtfishman@users.noreply.github.com> Co-authored-by: mtfishman Co-authored-by: Matt Fishman --- .JuliaFormatter.toml | 3 - .github/workflows/FormatCheck.yml | 13 +- .pre-commit-config.yaml | 8 +- Project.toml | 2 +- docs/make.jl | 22 +-- docs/make_index.jl | 16 +- docs/make_readme.jl | 16 +- examples/README.jl | 2 +- src/abstracttensornetwork.jl | 258 +++++++++++++++--------------- src/tensornetwork.jl | 70 ++++---- test/runtests.jl | 80 ++++----- test/test_aqua.jl | 2 +- test/test_basics.jl | 102 ++++++------ 13 files changed, 298 insertions(+), 296 deletions(-) delete mode 100644 .JuliaFormatter.toml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml deleted file mode 100644 index 4c49a86..0000000 --- a/.JuliaFormatter.toml +++ /dev/null @@ -1,3 +0,0 @@ -# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options -style = "blue" -indent = 2 diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index 3f78afc..1525861 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,11 +1,14 @@ name: "Format Check" on: - push: - branches: - - 'main' - tags: '*' - pull_request: + pull_request_target: + paths: ['**/*.jl'] + types: [opened, synchronize, reopened, ready_for_review] + +permissions: + contents: read + actions: write + pull-requests: write jobs: format-check: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 88bc8b4..3fc4743 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,5 @@ ci: - skip: [julia-formatter] + skip: [runic] repos: - repo: https://github.com/pre-commit/pre-commit-hooks @@ -11,7 +11,7 @@ repos: - id: end-of-file-fixer exclude_types: [markdown] # incompatible with Literate.jl -- repo: "https://github.com/domluna/JuliaFormatter.jl" - rev: v2.1.6 +- repo: https://github.com/fredrikekre/runic-pre-commit + rev: v2.0.1 
hooks: - - id: "julia-formatter" + - id: runic diff --git a/Project.toml b/Project.toml index 21bc2ee..ff60a76 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/docs/make.jl b/docs/make.jl index 5a50658..1b29518 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,23 +2,23 @@ using ITensorNetworksNext: ITensorNetworksNext using Documenter: Documenter, DocMeta, deploydocs, makedocs DocMeta.setdocmeta!( - ITensorNetworksNext, :DocTestSetup, :(using ITensorNetworksNext); recursive=true + ITensorNetworksNext, :DocTestSetup, :(using ITensorNetworksNext); recursive = true ) include("make_index.jl") makedocs(; - modules=[ITensorNetworksNext], - authors="ITensor developers and contributors", - sitename="ITensorNetworksNext.jl", - format=Documenter.HTML(; - canonical="https://itensor.github.io/ITensorNetworksNext.jl", - edit_link="main", - assets=["assets/favicon.ico", "assets/extras.css"], - ), - pages=["Home" => "index.md", "Reference" => "reference.md"], + modules = [ITensorNetworksNext], + authors = "ITensor developers and contributors", + sitename = "ITensorNetworksNext.jl", + format = Documenter.HTML(; + canonical = "https://itensor.github.io/ITensorNetworksNext.jl", + edit_link = "main", + assets = ["assets/favicon.ico", "assets/extras.css"], + ), + pages = ["Home" => "index.md", "Reference" => "reference.md"], ) deploydocs(; - repo="github.com/ITensor/ITensorNetworksNext.jl", devbranch="main", push_preview=true + repo = "github.com/ITensor/ITensorNetworksNext.jl", devbranch = "main", push_preview = true ) diff --git a/docs/make_index.jl b/docs/make_index.jl index 44fa493..038bc87 100644 --- a/docs/make_index.jl +++ b/docs/make_index.jl @@ -2,20 +2,20 @@ using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ ```@raw html Flatiron Center for Computational Quantum Physics logo. Flatiron Center for Computational Quantum Physics logo. ``` """ - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), - joinpath(pkgdir(ITensorNetworksNext), "docs", "src"); - flavor=Literate.DocumenterFlavor(), - name="index", - postprocess=ccq_logo, + joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), + joinpath(pkgdir(ITensorNetworksNext), "docs", "src"); + flavor = Literate.DocumenterFlavor(), + name = "index", + postprocess = ccq_logo, ) diff --git a/docs/make_readme.jl b/docs/make_readme.jl index 960d376..088dc58 100644 --- a/docs/make_readme.jl +++ b/docs/make_readme.jl @@ -2,20 +2,20 @@ using Literate: Literate using ITensorNetworksNext: ITensorNetworksNext function ccq_logo(content) - include_ccq_logo = """ + include_ccq_logo = """ Flatiron Center for Computational Quantum Physics logo. 
""" - content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) - return content + content = replace(content, "{CCQ_LOGO}" => include_ccq_logo) + return content end Literate.markdown( - joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), - joinpath(pkgdir(ITensorNetworksNext)); - flavor=Literate.CommonMarkFlavor(), - name="README", - postprocess=ccq_logo, + joinpath(pkgdir(ITensorNetworksNext), "examples", "README.jl"), + joinpath(pkgdir(ITensorNetworksNext)); + flavor = Literate.CommonMarkFlavor(), + name = "README", + postprocess = ccq_logo, ) diff --git a/examples/README.jl b/examples/README.jl index 4aaa79b..e3ee854 100644 --- a/examples/README.jl +++ b/examples/README.jl @@ -1,5 +1,5 @@ # # ITensorNetworksNext.jl -# +# # [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://itensor.github.io/ITensorNetworksNext.jl/stable/) # [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://itensor.github.io/ITensorNetworksNext.jl/dev/) # [![Build Status](https://github.com/ITensor/ITensorNetworksNext.jl/actions/workflows/Tests.yml/badge.svg?branch=main)](https://github.com/ITensor/ITensorNetworksNext.jl/actions/workflows/Tests.yml?query=branch%3Amain) diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index e666e93..73bf9d6 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -1,44 +1,44 @@ using Adapt: Adapt, adapt, adapt_structure using BackendSelection: @Algorithm_str, Algorithm using DataGraphs: - DataGraphs, - AbstractDataGraph, - edge_data, - underlying_graph, - underlying_graph_type, - vertex_data + DataGraphs, + AbstractDataGraph, + edge_data, + underlying_graph, + underlying_graph_type, + vertex_data using Dictionaries: Dictionary using Graphs: - Graphs, - AbstractEdge, - AbstractGraph, - Graph, - add_edge!, - add_vertex!, - bfs_tree, - center, - dst, - edges, - edgetype, - ne, - neighbors, - nv, - rem_edge!, - src, - vertices + Graphs, + AbstractEdge, + AbstractGraph, + Graph, + add_edge!, + add_vertex!, + bfs_tree, + center, + dst, + edges, + edgetype, + ne, + neighbors, + nv, + rem_edge!, + src, + vertices using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture using NamedDimsArrays: dimnames using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree using NamedGraphs.GraphsExtensions: - ⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype + ⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype using SplitApplyCombine: flatten -abstract type AbstractTensorNetwork{V,VD} <: AbstractDataGraph{V,VD,Nothing} end +abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end function Graphs.rem_edge!(tn::AbstractTensorNetwork, e) - rem_edge!(underlying_graph(tn), e) - return tn + rem_edge!(underlying_graph(tn), e) + return tn end # TODO: Define a generic fallback for `AbstractDataGraph`? 
@@ -46,14 +46,14 @@ DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge da # Graphs.jl overloads function Graphs.weights(graph::AbstractTensorNetwork) - V = vertextype(graph) - es = Tuple.(edges(graph)) - ws = Dictionary{Tuple{V,V},Float64}(es, undef) - for e in edges(graph) - w = log2(dim(commoninds(graph, e))) - ws[(src(e), dst(e))] = w - end - return ws + V = vertextype(graph) + es = Tuple.(edges(graph)) + ws = Dictionary{Tuple{V, V}, Float64}(es, undef) + for e in edges(graph) + w = log2(dim(commoninds(graph, e))) + ws[(src(e), dst(e))] = w + end + return ws end # Copy @@ -71,85 +71,85 @@ Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false # Derived interface, may need to be overloaded function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork}) - return underlying_graph_type(data_graph_type(G)) + return underlying_graph_type(data_graph_type(G)) end # AbstractDataGraphs overloads function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") + return error("Not implemented") end function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...) - return error("Not implemented") + return error("Not implemented") end DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented") function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork) - return NamedGraphs.vertex_positions(underlying_graph(tn)) + return NamedGraphs.vertex_positions(underlying_graph(tn)) end function NamedGraphs.ordered_vertices(tn::AbstractTensorNetwork) - return NamedGraphs.ordered_vertices(underlying_graph(tn)) + return NamedGraphs.ordered_vertices(underlying_graph(tn)) end function Adapt.adapt_structure(to, tn::AbstractTensorNetwork) - # TODO: Define and use: - # - # @preserve_graph map_vertex_data(adapt(to), tn) - # - # or just: - # - # @preserve_graph map(adapt(to), tn) - return map_vertex_data_preserve_graph(adapt(to), tn) + # TODO: Define and use: + # + # @preserve_graph map_vertex_data(adapt(to), tn) + # + # or just: + # + # @preserve_graph map(adapt(to), tn) + return map_vertex_data_preserve_graph(adapt(to), tn) end function linkinds(tn::AbstractTensorNetwork, edge::Pair) - return linkinds(tn, edgetype(tn)(edge)) + return linkinds(tn, edgetype(tn)(edge)) end function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return nameddimsindices(tn[src(edge)]) ∩ nameddimsindices(tn[dst(edge)]) + return nameddimsindices(tn[src(edge)]) ∩ nameddimsindices(tn[dst(edge)]) end function linkaxes(tn::AbstractTensorNetwork, edge::Pair) - return linkaxes(tn, edgetype(tn)(edge)) + return linkaxes(tn, edgetype(tn)(edge)) end function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) - return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) + return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) end function linknames(tn::AbstractTensorNetwork, edge::Pair) - return linknames(tn, edgetype(tn)(edge)) + return linknames(tn, edgetype(tn)(edge)) end function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) - return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) + return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) end function siteinds(tn::AbstractTensorNetwork, v) - s = nameddimsindices(tn[v]) - for v′ in neighbors(tn, v) - s = setdiff(s, nameddimsindices(tn[v′])) - end - return s + s = nameddimsindices(tn[v]) + for v′ in neighbors(tn, v) + s = setdiff(s, nameddimsindices(tn[v′])) + end + return s end function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge) - s = axes(tn[src(edge)]) ∩ 
axes(tn[dst(edge)]) - for v′ in neighbors(tn, v) - s = setdiff(s, axes(tn[v′])) - end - return s + s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)]) + for v′ in neighbors(tn, v) + s = setdiff(s, axes(tn[v′])) + end + return s end function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge) - s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) - for v′ in neighbors(tn, v) - s = setdiff(s, dimnames(tn[v′])) - end - return s + s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)]) + for v′ in neighbors(tn, v) + s = setdiff(s, dimnames(tn[v′])) + end + return s end function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex) - vertex_data(tn)[vertex] = value - return tn + vertex_data(tn)[vertex] = value + return tn end # TODO: Move to `BaseExtensions` module. function is_setindex!_expr(expr::Expr) - return is_assignment_expr(expr) && is_getindex_expr(first(expr.args)) + return is_assignment_expr(expr) && is_getindex_expr(first(expr.args)) end is_setindex!_expr(x) = false is_getindex_expr(expr::Expr) = (expr.head === :ref) @@ -162,118 +162,118 @@ is_assignment_expr(expr) = false # preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph # Also allow annotating codeblocks like `@views`. macro preserve_graph(expr) - if !is_setindex!_expr(expr) - error( - "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", - ) - end - @capture(expr, array_[indices__] = value_) - return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...))) + if !is_setindex!_expr(expr) + error( + "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", + ) + end + @capture(expr, array_[indices__] = value_) + return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...))) end # Update the graph of the TensorNetwork `tn` to include # edges that should exist based on the tensor connectivity. function add_missing_edges!(tn::AbstractTensorNetwork) - foreach(v -> add_missing_edges!(tn, v), vertices(tn)) - return tn + foreach(v -> add_missing_edges!(tn, v), vertices(tn)) + return tn end # Update the graph of the TensorNetwork `tn` to include # edges that should be incident to the vertex `v` # based on the tensor connectivity. function add_missing_edges!(tn::AbstractTensorNetwork, v) - for v′ in vertices(tn) - if v ≠ v′ - e = v => v′ - if !isempty(linkinds(tn, e)) - add_edge!(tn, e) - end + for v′ in vertices(tn) + if v ≠ v′ + e = v => v′ + if !isempty(linkinds(tn, e)) + add_edge!(tn, e) + end + end end - end - return tn + return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity. function fix_edges!(tn::AbstractTensorNetwork) - foreach(v -> fix_edges!(tn, v), vertices(tn)) - return tn + foreach(v -> fix_edges!(tn, v), vertices(tn)) + return tn end # Fix the edges of the TensorNetwork `tn` to match # the tensor connectivity at vertex `v`. function fix_edges!(tn::AbstractTensorNetwork, v) - rem_incident_edges!(tn, v) - rem_edges!(tn, incident_edges(tn, v)) - add_missing_edges!(tn, v) - return tn + rem_incident_edges!(tn, v) + rem_edges!(tn, incident_edges(tn, v)) + add_missing_edges!(tn, v) + return tn end # Customization point. 
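To make the two assignment paths concrete (a sketch, with `tn` a tensor network and `t` a tensor; `Base.setindex!` just below is defined to run `fix_edges!` after assigning):

# Plain assignment re-derives the edges incident to `v` from the
# dimension names of the new tensor:
tn[v] = t
# Annotated assignment skips the edge fix-up, for callers that know
# the connectivity is unchanged; it lowers to
# `setindex_preserve_graph!(tn, t, v)`:
@preserve_graph tn[v] = t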
using NamedDimsArrays: AbstractNamedUnitRange, namedunitrange, nametype, randname function trivial_unitrange(type::Type{<:AbstractUnitRange}) - return Base.oneto(one(eltype(type))) + return Base.oneto(one(eltype(type))) end function rand_trivial_namedunitrange( - ::Type{<:AbstractNamedUnitRange{<:Any,R,N}} -) where {R,N} - return namedunitrange(trivial_unitrange(R), randname(N)) + ::Type{<:AbstractNamedUnitRange{<:Any, R, N}} + ) where {R, N} + return namedunitrange(trivial_unitrange(R), randname(N)) end dag(x) = x using NamedDimsArrays: nameddimsindices function insert_trivial_link!(tn, e) - add_edge!(tn, e) - l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)]))) - x = similar(tn[src(e)], (l,)) - x[1] = 1 - @preserve_graph tn[src(e)] = tn[src(e)] * x - @preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x) - return tn + add_edge!(tn, e) + l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)]))) + x = similar(tn[src(e)], (l,)) + x[1] = 1 + @preserve_graph tn[src(e)] = tn[src(e)] * x + @preserve_graph tn[dst(e)] = tn[dst(e)] * dag(x) + return tn end function Base.setindex!(tn::AbstractTensorNetwork, value, v) - @preserve_graph tn[v] = value - fix_edges!(tn, v) - return tn + @preserve_graph tn[v] = value + fix_edges!(tn, v) + return tn end using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger) - graph[vertices(graph)[vertex]] = value - return graph + graph[vertices(graph)[vertex]] = value + return graph end # Fix ambiguity error. function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) - return error("No edge data.") + return error("No edge data.") end # Fix ambiguity error. function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) - return error("No edge data.") + return error("No edge data.") end using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger # Fix ambiguity error. 
function Base.setindex!( - tn::AbstractTensorNetwork, - value, - edge::Pair{<:OrdinalSuffixedInteger,<:OrdinalSuffixedInteger}, -) - return error("No edge data.") + tn::AbstractTensorNetwork, + value, + edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger}, + ) + return error("No edge data.") end function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) - println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") - show(io, mime, vertices(graph)) - println(io, "\n") - println(io, "and $(ne(graph)) edge(s):") - for e in edges(graph) - show(io, mime, e) + println(io, "$(typeof(graph)) with $(nv(graph)) vertices:") + show(io, mime, vertices(graph)) + println(io, "\n") + println(io, "and $(ne(graph)) edge(s):") + for e in edges(graph) + show(io, mime, e) + println(io) + end println(io) - end - println(io) - println(io, "with vertex data:") - show(io, mime, axes.(vertex_data(graph))) - return nothing + println(io, "with vertex data:") + show(io, mime, axes.(vertex_data(graph))) + return nothing end Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 3fd794b..7423669 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -7,68 +7,68 @@ using NamedGraphs.GraphsExtensions: arranged_edges, vertextype function _TensorNetwork end -struct TensorNetwork{V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} <: - AbstractTensorNetwork{V,VD} - underlying_graph::UG - tensors::Tensors - global @inline function _TensorNetwork( - underlying_graph::UG, tensors::Tensors - ) where {V,VD,UG<:AbstractGraph{V},Tensors<:AbstractDictionary{V,VD}} - # This assumes the tensor connectivity matches the graph structure. - return new{V,VD,UG,Tensors}(underlying_graph, tensors) - end +struct TensorNetwork{V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionary{V, VD}} <: + AbstractTensorNetwork{V, VD} + underlying_graph::UG + tensors::Tensors + global @inline function _TensorNetwork( + underlying_graph::UG, tensors::Tensors + ) where {V, VD, UG <: AbstractGraph{V}, Tensors <: AbstractDictionary{V, VD}} + # This assumes the tensor connectivity matches the graph structure. + return new{V, VD, UG, Tensors}(underlying_graph, tensors) + end end DataGraphs.underlying_graph(tn::TensorNetwork) = getfield(tn, :underlying_graph) DataGraphs.vertex_data(tn::TensorNetwork) = getfield(tn, :tensors) function DataGraphs.underlying_graph_type(type::Type{<:TensorNetwork}) - return fieldtype(type, :underlying_graph) + return fieldtype(type, :underlying_graph) end # Determine the graph structure from the tensors. 
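As an illustration of the inference the constructor below performs (a sketch with made-up vertex and dimension names):

using Dictionaries: dictionary
using NamedDimsArrays: nameddims
t = dictionary([
    "A" => nameddims(randn(2, 2), (:i, :j)),
    "B" => nameddims(randn(2, 2), (:j, :k)),
])
tn = TensorNetwork(t)
# "A" and "B" share the dimension name :j, so the inferred graph has
# the single edge "A" => "B".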
function TensorNetwork(t::AbstractDictionary) - g = NamedGraph(eachindex(t)) - for v1 in vertices(g) - for v2 in vertices(g) - if v1 ≠ v2 - if !isdisjoint(dimnames(t[v1]), dimnames(t[v2])) - add_edge!(g, v1 => v2) + g = NamedGraph(eachindex(t)) + for v1 in vertices(g) + for v2 in vertices(g) + if v1 ≠ v2 + if !isdisjoint(dimnames(t[v1]), dimnames(t[v2])) + add_edge!(g, v1 => v2) + end + end end - end end - end - return _TensorNetwork(g, t) + return _TensorNetwork(g, t) end function TensorNetwork(tensors::AbstractDict) - return TensorNetwork(Dictionary(tensors)) + return TensorNetwork(Dictionary(tensors)) end function TensorNetwork(graph::AbstractGraph, tensors::AbstractDictionary) - tn = TensorNetwork(tensors) - arranged_edges(tn) ⊆ arranged_edges(graph) || - error("The edges in the tensors do not match the graph structure.") - for e in setdiff(arranged_edges(graph), arranged_edges(tn)) - insert_trivial_link!(tn, e) - end - return tn + tn = TensorNetwork(tensors) + arranged_edges(tn) ⊆ arranged_edges(graph) || + error("The edges in the tensors do not match the graph structure.") + for e in setdiff(arranged_edges(graph), arranged_edges(tn)) + insert_trivial_link!(tn, e) + end + return tn end function TensorNetwork(graph::AbstractGraph, tensors::AbstractDict) - return TensorNetwork(graph, Dictionary(tensors)) + return TensorNetwork(graph, Dictionary(tensors)) end function TensorNetwork(f, graph::AbstractGraph) - return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph))) + return TensorNetwork(graph, Dict(v => f(v) for v in vertices(graph))) end function Base.copy(tn::TensorNetwork) - TensorNetwork(copy(underlying_graph(tn)), copy(vertex_data(tn))) + return TensorNetwork(copy(underlying_graph(tn)), copy(vertex_data(tn))) end TensorNetwork(tn::TensorNetwork) = copy(tn) TensorNetwork{V}(tn::TensorNetwork{V}) where {V} = copy(tn) function TensorNetwork{V}(tn::TensorNetwork) where {V} - g′ = convert_vertextype(V, underlying_graph(tn)) - d = vertex_data(tn) - d′ = dictionary(V(k) => d[k] for k in eachindex(d)) - return TensorNetwork(g′, d′) + g′ = convert_vertextype(V, underlying_graph(tn)) + d = vertex_data(tn) + d′ = dictionary(V(k) => d[k] for k in eachindex(d)) + return TensorNetwork(g′, d′) end NamedGraphs.convert_vertextype(::Type{V}, tn::TensorNetwork{V}) where {V} = tn diff --git a/test/runtests.jl b/test/runtests.jl index 98b2d2b..0008050 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,60 +6,62 @@ using Suppressor: Suppressor const pat = r"(?:--group=)(\w+)" arg_id = findfirst(contains(pat), ARGS) const GROUP = uppercase( - if isnothing(arg_id) - get(ENV, "GROUP", "ALL") - else - only(match(pat, ARGS[arg_id]).captures) - end, + if isnothing(arg_id) + get(ENV, "GROUP", "ALL") + else + only(match(pat, ARGS[arg_id]).captures) + end, ) "match files of the form `test_*.jl`, but exclude `*setup*.jl`" function istestfile(fn) - return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") + return endswith(fn, ".jl") && startswith(basename(fn), "test_") && !contains(fn, "setup") end "match files of the form `*.jl`, but exclude `*_notest.jl` and `*setup*.jl`" function isexamplefile(fn) - return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") + return endswith(fn, ".jl") && !endswith(fn, "_notest.jl") && !contains(fn, "setup") end @time begin - # tests in groups based on folder structure - for testgroup in filter(isdir, readdir(@__DIR__)) - if GROUP == "ALL" || GROUP == uppercase(testgroup) - groupdir = 
joinpath(@__DIR__, testgroup) - for file in filter(istestfile, readdir(groupdir)) - filename = joinpath(groupdir, file) - @eval @safetestset $file begin - include($filename) + # tests in groups based on folder structure + for testgroup in filter(isdir, readdir(@__DIR__)) + if GROUP == "ALL" || GROUP == uppercase(testgroup) + groupdir = joinpath(@__DIR__, testgroup) + for file in filter(istestfile, readdir(groupdir)) + filename = joinpath(groupdir, file) + @eval @safetestset $file begin + include($filename) + end + end end - end end - end - # single files in top folder - for file in filter(istestfile, readdir(@__DIR__)) - (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion - @eval @safetestset $file begin - include($file) + # single files in top folder + for file in filter(istestfile, readdir(@__DIR__)) + (file == basename(@__FILE__)) && continue # exclude this file to avoid infinite recursion + @eval @safetestset $file begin + include($file) + end end - end - # test examples - examplepath = joinpath(@__DIR__, "..", "examples") - for (root, _, files) in walkdir(examplepath) - contains(chopprefix(root, @__DIR__), "setup") && continue - for file in filter(isexamplefile, files) - filename = joinpath(root, file) - @eval begin - @safetestset $file begin - $(Expr( - :macrocall, - GlobalRef(Suppressor, Symbol("@suppress")), - LineNumberNode(@__LINE__, @__FILE__), - :(include($filename)), - )) + # test examples + examplepath = joinpath(@__DIR__, "..", "examples") + for (root, _, files) in walkdir(examplepath) + contains(chopprefix(root, @__DIR__), "setup") && continue + for file in filter(isexamplefile, files) + filename = joinpath(root, file) + @eval begin + @safetestset $file begin + $( + Expr( + :macrocall, + GlobalRef(Suppressor, Symbol("@suppress")), + LineNumberNode(@__LINE__, @__FILE__), + :(include($filename)), + ) + ) + end + end end - end end - end end diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 34bfff1..0afead5 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -3,5 +3,5 @@ using Aqua: Aqua using Test: @testset @testset "Code quality (Aqua.jl)" begin - Aqua.test_all(ITensorNetworksNext) + Aqua.test_all(ITensorNetworksNext) end diff --git a/test/test_basics.jl b/test/test_basics.jl index 59e5e35..0c9d803 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -8,56 +8,56 @@ using NamedGraphs.NamedGraphGenerators: named_grid using Test: @test, @testset @testset "ITensorNetworksNext" begin - @testset "Construct TensorNetwork product state" begin - dims = (3, 3) - g = named_grid(dims) - s = Dict(v => Index(2) for v in vertices(g)) - tn = TensorNetwork(g) do v - return randn(s[v]) + @testset "Construct TensorNetwork product state" begin + dims = (3, 3) + g = named_grid(dims) + s = Dict(v => Index(2) for v in vertices(g)) + tn = TensorNetwork(g) do v + return randn(s[v]) + end + @test nv(tn) == 9 + @test ne(tn) == ne(g) + @test issetequal(vertices(tn), vertices(g)) + @test issetequal(arranged_edges(tn), arranged_edges(g)) + for v in vertices(tn) + @test siteinds(tn, v) == [s[v]] + end + for v1 in vertices(tn) + for v2 in vertices(tn) + v1 == v2 && continue + haslink = !isempty(linkinds(tn, v1 => v2)) + @test haslink == has_edge(tn, v1 => v2) + end + end + for e in edges(tn) + @test isone(length(only(linkinds(tn, e)))) + end + end + @testset "Construct TensorNetwork partition function" begin + dims = (3, 3) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) 
+ tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + @test nv(tn) == 9 + @test ne(tn) == ne(g) + @test issetequal(vertices(tn), vertices(g)) + @test issetequal(arranged_edges(tn), arranged_edges(g)) + for v in vertices(tn) + @test isempty(siteinds(tn, v)) + end + for v1 in vertices(tn) + for v2 in vertices(tn) + v1 == v2 && continue + haslink = !isempty(linkinds(tn, v1 => v2)) + @test haslink == has_edge(tn, v1 => v2) + end + end + for e in edges(tn) + @test only(linkinds(tn, e)) == l[e] + end end - @test nv(tn) == 9 - @test ne(tn) == ne(g) - @test issetequal(vertices(tn), vertices(g)) - @test issetequal(arranged_edges(tn), arranged_edges(g)) - for v in vertices(tn) - @test siteinds(tn, v) == [s[v]] - end - for v1 in vertices(tn) - for v2 in vertices(tn) - v1 == v2 && continue - haslink = !isempty(linkinds(tn, v1 => v2)) - @test haslink == has_edge(tn, v1 => v2) - end - end - for e in edges(tn) - @test isone(length(only(linkinds(tn, e)))) - end - end - @testset "Construct TensorNetwork partition function" begin - dims = (3, 3) - g = named_grid(dims) - l = Dict(e => Index(2) for e in edges(g)) - l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) - tn = TensorNetwork(g) do v - is = map(e -> l[e], incident_edges(g, v)) - return randn(Tuple(is)) - end - @test nv(tn) == 9 - @test ne(tn) == ne(g) - @test issetequal(vertices(tn), vertices(g)) - @test issetequal(arranged_edges(tn), arranged_edges(g)) - for v in vertices(tn) - @test isempty(siteinds(tn, v)) - end - for v1 in vertices(tn) - for v2 in vertices(tn) - v1 == v2 && continue - haslink = !isempty(linkinds(tn, v1 => v2)) - @test haslink == has_edge(tn, v1 => v2) - end - end - for e in edges(tn) - @test only(linkinds(tn, e)) == l[e] - end - end end From 88a6203c6b4d69f32081123e63281d16f2f8de1f Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Mon, 6 Oct 2025 21:09:36 -0400 Subject: [PATCH 03/13] Upgrade to NamedDimsArrays v0.8 (#8) --- Project.toml | 4 ++-- src/abstracttensornetwork.jl | 11 +++++------ src/tensornetwork.jl | 2 +- test/Project.toml | 4 ++-- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index ff60a76..fc1ba2d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -24,7 +24,7 @@ Dictionaries = "0.4.5" Graphs = "1.13.1" LinearAlgebra = "1.10" MacroTools = "0.5.16" -NamedDimsArrays = "0.7.13" +NamedDimsArrays = "0.8" NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index 73bf9d6..cdcf409 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -28,7 +28,7 @@ using Graphs: vertices using LinearAlgebra: LinearAlgebra, factorize using MacroTools: @capture -using NamedDimsArrays: dimnames +using NamedDimsArrays: dimnames, inds using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!, rename_vertices, vertextype @@ -105,7 +105,7 @@ function linkinds(tn::AbstractTensorNetwork, edge::Pair) return linkinds(tn, edgetype(tn)(edge)) end function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge) - return nameddimsindices(tn[src(edge)]) ∩ nameddimsindices(tn[dst(edge)]) + return 
inds(tn[src(edge)]) ∩ inds(tn[dst(edge)]) end function linkaxes(tn::AbstractTensorNetwork, edge::Pair) return linkaxes(tn, edgetype(tn)(edge)) @@ -121,9 +121,9 @@ function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge) end function siteinds(tn::AbstractTensorNetwork, v) - s = nameddimsindices(tn[v]) + s = inds(tn[v]) for v′ in neighbors(tn, v) - s = setdiff(s, nameddimsindices(tn[v′])) + s = setdiff(s, inds(tn[v′])) end return s end @@ -221,10 +221,9 @@ end dag(x) = x -using NamedDimsArrays: nameddimsindices function insert_trivial_link!(tn, e) add_edge!(tn, e) - l = rand_trivial_namedunitrange(eltype(nameddimsindices(tn[src(e)]))) + l = rand_trivial_namedunitrange(eltype(inds(tn[src(e)]))) x = similar(tn[src(e)], (l,)) x[1] = 1 @preserve_graph tn[src(e)] = tn[src(e)] * x diff --git a/src/tensornetwork.jl b/src/tensornetwork.jl index 7423669..c7d1479 100644 --- a/src/tensornetwork.jl +++ b/src/tensornetwork.jl @@ -1,7 +1,7 @@ using DataGraphs: DataGraphs, AbstractDataGraph, DataGraph using Dictionaries: AbstractDictionary, Indices, dictionary using Graphs: AbstractSimpleGraph -using NamedDimsArrays: AbstractNamedDimsArray, dimnames, nameddimsarray +using NamedDimsArrays: AbstractNamedDimsArray, dimnames using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype using NamedGraphs.GraphsExtensions: arranged_edges, vertextype diff --git a/test/Project.toml b/test/Project.toml index 80debae..b22f9d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,9 +14,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Aqua = "0.8.14" Dictionaries = "0.4.5" Graphs = "1.13.1" -ITensorBase = "0.2.12" +ITensorBase = "0.3" ITensorNetworksNext = "0.1.1" -NamedDimsArrays = "0.7.14" +NamedDimsArrays = "0.8" NamedGraphs = "0.6.8, 0.7" SafeTestsets = "0.1" Suppressor = "0.2.8" From 1a6d2243f9b250959f59d4f1978bf8273379a20b Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Mon, 6 Oct 2025 23:02:29 -0400 Subject: [PATCH 04/13] LazyNamedDimsArrays (#9) --- Project.toml | 8 +- src/ITensorNetworksNext.jl | 1 + src/lazynameddimsarrays.jl | 182 +++++++++++++++++++++++++++++++ test/Project.toml | 4 + test/test_lazynameddimsarrays.jl | 56 ++++++++++ 5 files changed, 249 insertions(+), 2 deletions(-) create mode 100644 src/lazynameddimsarrays.jl create mode 100644 test/test_lazynameddimsarrays.jl diff --git a/Project.toml b/Project.toml index fc1ba2d..b527a53 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.5" +version = "0.1.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -15,9 +15,11 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" +WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] -Adapt = "4.3.0" +Adapt = "4.3" BackendSelection = "0.1.6" DataGraphs = "0.2.7" Dictionaries = "0.4.5" @@ -28,4 +30,6 @@ NamedDimsArrays = "0.8" NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" +TermInterface = "2" +WrappedUnions = "0.3" julia = "1.10" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 89daa37..35c9e59 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -1,5 +1,6 @@ module ITensorNetworksNext 
+include("lazynameddimsarrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl new file mode 100644 index 0000000..04eca26 --- /dev/null +++ b/src/lazynameddimsarrays.jl @@ -0,0 +1,182 @@ +module LazyNamedDimsArrays + +using WrappedUnions: @wrapped, unwrap +using NamedDimsArrays: + NamedDimsArrays, + AbstractNamedDimsArray, + AbstractNamedDimsArrayStyle, + dename, + inds + +struct Prod{A} + factors::Vector{A} +end + +@wrapped struct LazyNamedDimsArray{ + T, A <: AbstractNamedDimsArray{T}, + } <: AbstractNamedDimsArray{T, Any} + union::Union{A, Prod{LazyNamedDimsArray{T, A}}} +end + +function NamedDimsArrays.inds(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return inds(unwrap(a)) + elseif unwrap(a) isa Prod + return mapreduce(inds, symdiff, unwrap(a).factors) + else + return error("Variant not supported.") + end +end +function NamedDimsArrays.dename(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return dename(unwrap(a)) + elseif unwrap(a) isa Prod + return dename(materialize(a), inds(a)) + else + return error("Variant not supported.") + end +end + +using Base.Broadcast: materialize +function Base.Broadcast.materialize(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return unwrap(a) + elseif unwrap(a) isa Prod + return prod(materialize, unwrap(a).factors) + else + return error("Variant not supported.") + end +end +Base.copy(a::LazyNamedDimsArray) = materialize(a) + +function Base.:*(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return LazyNamedDimsArray(Prod([lazy(unwrap(a))])) + elseif unwrap(a) isa Prod + return a + else + return error("Variant not supported.") + end +end + +function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + # Nested by default. + return LazyNamedDimsArray(Prod([a1, a2])) +end +function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + return error("Not implemented.") +end +function Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + return error("Not implemented.") +end +function Base.:*(c::Number, a::LazyNamedDimsArray) + return error("Not implemented.") +end +function Base.:*(a::LazyNamedDimsArray, c::Number) + return error("Not implemented.") +end +function Base.:/(a::LazyNamedDimsArray, c::Number) + return error("Not implemented.") +end +function Base.:-(a::LazyNamedDimsArray) + return error("Not implemented.") +end + +function LazyNamedDimsArray(a::AbstractNamedDimsArray) + return LazyNamedDimsArray{eltype(a), typeof(a)}(a) +end +function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A} + return LazyNamedDimsArray{T, A}(a) +end +function lazy(a::AbstractNamedDimsArray) + return LazyNamedDimsArray(a) +end + +# Broadcasting +struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end +function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray}) + return LazyNamedDimsArrayStyle() +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) + return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") +end +# Linear operations. 
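Before the broadcast forwarding rules that follow, the core behavior this patch adds can be summarized roughly as (a sketch using the names defined above):

using ITensorNetworksNext.LazyNamedDimsArrays: lazy
using NamedDimsArrays: nameddims
a = nameddims(randn(2, 2), (:i, :j))
b = nameddims(randn(2, 2), (:j, :k))
l = lazy(a) * lazy(b)  # builds a `Prod` node; nothing is contracted yet
copy(l)                # `materialize` folds the factors, giving `a * b`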
+function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) + return a1 + a2 +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) + return a1 - a2 +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) + return c * a +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) + return a * c +end +# Fix ambiguity error. +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) + return a * b +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) + return a / c +end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) + return -a +end + +using TermInterface: TermInterface +# arguments, arity, children, head, iscall, operation +function TermInterface.arguments(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return error("No arguments.") + elseif unwrap(a) isa Prod + unwrap(a).factors + else + return error("Variant not supported.") + end +end +function TermInterface.children(a::LazyNamedDimsArray) + return TermInterface.arguments(a) +end +function TermInterface.head(a::LazyNamedDimsArray) + return TermInterface.operation(a) +end +function TermInterface.iscall(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return false + elseif unwrap(a) isa Prod + return true + else + return false + end +end +function TermInterface.isexpr(a::LazyNamedDimsArray) + return TermInterface.iscall(a) +end +function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) + if head ≡ prod + return LazyNamedDimsArray(Prod(args)) + else + return error("Only product terms supported right now.") + end +end +function TermInterface.operation(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return error("No operation.") + elseif unwrap(a) isa Prod + prod + else + return error("Variant not supported.") + end +end +function TermInterface.sorted_arguments(a::LazyNamedDimsArray) + if unwrap(a) isa AbstractNamedDimsArray + return error("No arguments.") + elseif unwrap(a) isa Prod + return TermInterface.arguments(a) + else + return error("Variant not supported.") + end +end + +end diff --git a/test/Project.toml b/test/Project.toml index b22f9d1..9646508 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,7 +8,9 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] Aqua = "0.8.14" @@ -20,4 +22,6 @@ NamedDimsArrays = "0.8" NamedGraphs = "0.6.8, 0.7" SafeTestsets = "0.1" Suppressor = "0.2.8" +TermInterface = "2" Test = "1.10" +WrappedUnions = "0.3" diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl new file mode 100644 index 0000000..958c191 --- /dev/null +++ b/test/test_lazynameddimsarrays.jl @@ -0,0 +1,56 @@ +using Base.Broadcast: materialize +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy +using NamedDimsArrays: NamedDimsArray, inds, nameddims +using TermInterface: + arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments +using Test: @test, @test_throws, @testset +using WrappedUnions: unwrap + +@testset 
"LazyNamedDimsArrays" begin + @testset "Basics" begin + a1 = nameddims(randn(2, 2), (:i, :j)) + a2 = nameddims(randn(2, 2), (:j, :k)) + a3 = nameddims(randn(2, 2), (:k, :l)) + l1, l2, l3 = lazy.((a1, a2, a3)) + for li in (l1, l2, l3) + @test li isa LazyNamedDimsArray + @test unwrap(li) isa NamedDimsArray + @test inds(li) == inds(unwrap(li)) + @test copy(li) == unwrap(li) + @test materialize(li) == unwrap(li) + end + l = l1 * l2 * l3 + @test copy(l) ≈ a1 * a2 * a3 + @test materialize(l) ≈ a1 * a2 * a3 + @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) + @test unwrap(l) isa Prod + @test unwrap(l).factors == [l1 * l2, l3] + end + + @testset "TermInterface" begin + a1 = nameddims(randn(2, 2), (:i, :j)) + a2 = nameddims(randn(2, 2), (:j, :k)) + a3 = nameddims(randn(2, 2), (:k, :l)) + l1, l2, l3 = lazy.((a1, a2, a3)) + + @test_throws ErrorException arguments(l1) + @test_throws ErrorException arity(l1) + @test_throws ErrorException children(l1) + @test_throws ErrorException head(l1) + @test !iscall(l1) + @test !isexpr(l1) + @test_throws ErrorException operation(l1) + @test_throws ErrorException sorted_arguments(l1) + + l = l1 * l2 * l3 + @test arguments(l) == [l1 * l2, l3] + @test arity(l) == 2 + @test children(l) == [l1 * l2, l3] + @test head(l) ≡ prod + @test iscall(l) + @test isexpr(l) + @test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing) + @test operation(l) ≡ prod + @test sorted_arguments(l) == [l1 * l2, l3] + end +end From 8a6158e2c990e4b0bc4548f4a2050ad7a5dc2908 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 7 Oct 2025 00:26:07 -0400 Subject: [PATCH 05/13] Change Prod to Mul (#10) --- Project.toml | 2 +- src/lazynameddimsarrays.jl | 162 +++++++++++++++++-------------- test/test_lazynameddimsarrays.jl | 28 ++++-- 3 files changed, 109 insertions(+), 83 deletions(-) diff --git a/Project.toml b/Project.toml index b527a53..b77134e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 04eca26..3561cb6 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -7,42 +7,108 @@ using NamedDimsArrays: AbstractNamedDimsArrayStyle, dename, inds +using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments -struct Prod{A} - factors::Vector{A} -end +struct Mul{A} + arguments::Vector{A} +end +TermInterface.arguments(m::Mul) = getfield(m, :arguments) +TermInterface.children(m::Mul) = arguments(m) +TermInterface.head(m::Mul) = operation(m) +TermInterface.iscall(m::Mul) = true +TermInterface.isexpr(m::Mul) = iscall(m) +TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args) +TermInterface.operation(m::Mul) = * +TermInterface.sorted_arguments(m::Mul) = arguments(m) +TermInterface.sorted_children(m::Mul) = sorted_arguments(a) @wrapped struct LazyNamedDimsArray{ T, A <: AbstractNamedDimsArray{T}, } <: AbstractNamedDimsArray{T, Any} - union::Union{A, Prod{LazyNamedDimsArray{T, A}}} + union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end function NamedDimsArrays.inds(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return inds(unwrap(a)) - elseif unwrap(a) isa Prod - return mapreduce(inds, symdiff, unwrap(a).factors) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return inds(u) + elseif u 
isa Mul + return mapreduce(inds, symdiff, arguments(u)) else return error("Variant not supported.") end end function NamedDimsArrays.dename(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return dename(unwrap(a)) - elseif unwrap(a) isa Prod + u = unwrap(a) + if u isa AbstractNamedDimsArray + return dename(u) + elseif u isa Mul return dename(materialize(a), inds(a)) else return error("Variant not supported.") end end +function TermInterface.arguments(a::LazyNamedDimsArray) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return error("No arguments.") + elseif u isa Mul + return arguments(u) + else + return error("Variant not supported.") + end +end +function TermInterface.children(a::LazyNamedDimsArray) + return arguments(a) +end +function TermInterface.head(a::LazyNamedDimsArray) + return operation(a) +end +function TermInterface.iscall(a::LazyNamedDimsArray) + return iscall(unwrap(a)) +end +function TermInterface.isexpr(a::LazyNamedDimsArray) + return iscall(a) +end +function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) + if head ≡ * + return LazyNamedDimsArray(maketerm(Mul, head, args, metadata)) + else + return error("Only product terms supported right now.") + end +end +function TermInterface.operation(a::LazyNamedDimsArray) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return error("No operation.") + elseif u isa Mul + return operation(u) + else + return error("Variant not supported.") + end +end +function TermInterface.sorted_arguments(a::LazyNamedDimsArray) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return error("No arguments.") + elseif u isa Mul + return sorted_arguments(u) + else + return error("Variant not supported.") + end +end +function TermInterface.sorted_children(a::LazyNamedDimsArray) + return sorted_arguments(a) +end + using Base.Broadcast: materialize function Base.Broadcast.materialize(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return unwrap(a) - elseif unwrap(a) isa Prod - return prod(materialize, unwrap(a).factors) + u = unwrap(a) + if u isa AbstractNamedDimsArray + return u + elseif u isa Mul + return mapfoldl(materialize, operation(u), arguments(u)) else return error("Variant not supported.") end @@ -50,9 +116,10 @@ end Base.copy(a::LazyNamedDimsArray) = materialize(a) function Base.:*(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return LazyNamedDimsArray(Prod([lazy(unwrap(a))])) - elseif unwrap(a) isa Prod + u = unwrap(a) + if u isa AbstractNamedDimsArray + return LazyNamedDimsArray(Mul([lazy(u)])) + elseif u isa Mul return a else return error("Variant not supported.") @@ -61,7 +128,7 @@ end function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) # Nested by default. 
- return LazyNamedDimsArray(Prod([a1, a2])) + return LazyNamedDimsArray(Mul([a1, a2])) end function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) return error("Not implemented.") @@ -85,7 +152,7 @@ end function LazyNamedDimsArray(a::AbstractNamedDimsArray) return LazyNamedDimsArray{eltype(a), typeof(a)}(a) end -function LazyNamedDimsArray(a::Prod{LazyNamedDimsArray{T, A}}) where {T, A} +function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} return LazyNamedDimsArray{T, A}(a) end function lazy(a::AbstractNamedDimsArray) @@ -124,59 +191,4 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) return -a end -using TermInterface: TermInterface -# arguments, arity, children, head, iscall, operation -function TermInterface.arguments(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return error("No arguments.") - elseif unwrap(a) isa Prod - unwrap(a).factors - else - return error("Variant not supported.") - end -end -function TermInterface.children(a::LazyNamedDimsArray) - return TermInterface.arguments(a) -end -function TermInterface.head(a::LazyNamedDimsArray) - return TermInterface.operation(a) -end -function TermInterface.iscall(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return false - elseif unwrap(a) isa Prod - return true - else - return false - end -end -function TermInterface.isexpr(a::LazyNamedDimsArray) - return TermInterface.iscall(a) -end -function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) - if head ≡ prod - return LazyNamedDimsArray(Prod(args)) - else - return error("Only product terms supported right now.") - end -end -function TermInterface.operation(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return error("No operation.") - elseif unwrap(a) isa Prod - prod - else - return error("Variant not supported.") - end -end -function TermInterface.sorted_arguments(a::LazyNamedDimsArray) - if unwrap(a) isa AbstractNamedDimsArray - return error("No arguments.") - elseif unwrap(a) isa Prod - return TermInterface.arguments(a) - else - return error("Variant not supported.") - end -end - end diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 958c191..4c38c5e 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,8 +1,17 @@ using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Prod, lazy +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy using NamedDimsArrays: NamedDimsArray, inds, nameddims using TermInterface: - arguments, arity, children, head, iscall, isexpr, maketerm, operation, sorted_arguments + arguments, + arity, + children, + head, + iscall, + isexpr, + maketerm, + operation, + sorted_arguments, + sorted_children using Test: @test, @test_throws, @testset using WrappedUnions: unwrap @@ -23,8 +32,11 @@ using WrappedUnions: unwrap @test copy(l) ≈ a1 * a2 * a3 @test materialize(l) ≈ a1 * a2 * a3 @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) - @test unwrap(l) isa Prod - @test unwrap(l).factors == [l1 * l2, l3] + @test unwrap(l) isa Mul + @test unwrap(l).arguments == [l1 * l2, l3] + # TermInterface.jl + @test operation(unwrap(l)) ≡ * + @test arguments(unwrap(l)) == [l1 * l2, l3] end @testset "TermInterface" begin @@ -41,16 +53,18 @@ using WrappedUnions: unwrap @test !isexpr(l1) @test_throws ErrorException operation(l1) @test_throws ErrorException sorted_arguments(l1) + @test_throws 
ErrorException sorted_children(l1) l = l1 * l2 * l3 @test arguments(l) == [l1 * l2, l3] @test arity(l) == 2 @test children(l) == [l1 * l2, l3] - @test head(l) ≡ prod + @test head(l) ≡ * @test iscall(l) @test isexpr(l) - @test l == maketerm(LazyNamedDimsArray, prod, [l1 * l2, l3], nothing) - @test operation(l) ≡ prod + @test l == maketerm(LazyNamedDimsArray, *, [l1 * l2, l3], nothing) + @test operation(l) ≡ * @test sorted_arguments(l) == [l1 * l2, l3] + @test sorted_children(l) == [l1 * l2, l3] end end From 0c19c8bcb5f19fe0f723c3d7e7ef3eea77f9c827 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Tue, 7 Oct 2025 18:28:33 -0400 Subject: [PATCH 06/13] Better printing, equality, symbolic arrays (#11) --- Project.toml | 4 +- src/lazynameddimsarrays.jl | 161 ++++++++++++++++++++++++++++--- test/Project.toml | 2 + test/test_lazynameddimsarrays.jl | 43 ++++++++- 4 files changed, 191 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index b77134e..5cfd81d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.7" +version = "0.1.8" [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" @@ -19,6 +20,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] +AbstractTrees = "0.4.5" Adapt = "4.3" BackendSelection = "0.1.6" DataGraphs = "0.2.7" diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 3561cb6..4a038e8 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -1,14 +1,25 @@ module LazyNamedDimsArrays +using AbstractTrees: AbstractTrees using WrappedUnions: @wrapped, unwrap using NamedDimsArrays: NamedDimsArrays, AbstractNamedDimsArray, AbstractNamedDimsArrayStyle, + NamedDimsArray, dename, + dimnames, inds using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments +# Custom version of `AbstractTrees.printnode` to +# avoid type piracy when overloading on `AbstractNamedDimsArray`. 
+printnode(io::IO, x) = AbstractTrees.printnode(io, x) +function printnode(io::IO, a::AbstractNamedDimsArray) + show(io, collect(dimnames(a))) + return nothing +end + struct Mul{A} arguments::Vector{A} end @@ -21,6 +32,13 @@ TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args) TermInterface.operation(m::Mul) = * TermInterface.sorted_arguments(m::Mul) = arguments(m) TermInterface.sorted_children(m::Mul) = sorted_arguments(m) +ismul(x) = false +ismul(m::Mul) = true +function Base.show(io::IO, m::Mul) + args = map(arg -> sprint(printnode, arg), arguments(m)) + print(io, "(", join(args, " $(operation(m)) "), ")") + return nothing +end @wrapped struct LazyNamedDimsArray{ T, A <: AbstractNamedDimsArray{T}, @@ -30,9 +48,9 @@ end function NamedDimsArrays.inds(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return inds(u) - elseif u isa Mul + elseif ismul(u) return mapreduce(inds, symdiff, arguments(u)) else return error("Variant not supported.") @@ -40,10 +58,8 @@ end function NamedDimsArrays.dename(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return dename(u) - elseif u isa Mul - return dename(materialize(a), inds(a)) else return error("Variant not supported.") end end function TermInterface.arguments(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No arguments.") - elseif u isa Mul + elseif ismul(u) return arguments(u) else return error("Variant not supported.") @@ -75,14 +91,14 @@ function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) if head ≡ * return LazyNamedDimsArray(maketerm(Mul, head, args, metadata)) else - return error("Only product terms supported right now.") + return error("Only mul supported right now.") end end function TermInterface.operation(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No operation.") - elseif u isa Mul + elseif ismul(u) return operation(u) else return error("Variant not supported.") @@ -90,9 +106,9 @@ end function TermInterface.sorted_arguments(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return error("No arguments.") - elseif u isa Mul + elseif ismul(u) return sorted_arguments(u) else return error("Variant not supported.") @@ -101,13 +117,29 @@ end function TermInterface.sorted_children(a::LazyNamedDimsArray) return sorted_arguments(a) end +ismul(a::LazyNamedDimsArray) = ismul(unwrap(a)) + +function AbstractTrees.children(a::LazyNamedDimsArray) + if !iscall(a) + return () + else + return arguments(a) + end +end +function AbstractTrees.nodevalue(a::LazyNamedDimsArray) + if !iscall(a) + return unwrap(a) + else + return operation(a) + end +end using Base.Broadcast: materialize function Base.Broadcast.materialize(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return u - elseif u isa Mul + elseif ismul(u) return mapfoldl(materialize, operation(u), arguments(u)) else return error("Variant not supported.") @@ -115,11 +147,45 @@ end Base.copy(a::LazyNamedDimsArray) = materialize(a) +function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) + u1, u2 = unwrap.((a1, a2)) + if !iscall(u1) && !iscall(u2) + return u1 == u2 + elseif ismul(u1) && ismul(u2) +
return arguments(u1) == arguments(u2) + else + return false + end +end + +function printnode(io::IO, a::LazyNamedDimsArray) + return printnode(io, unwrap(a)) +end +function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) + return printnode(io, a) +end +function Base.show(io::IO, a::LazyNamedDimsArray) + if !iscall(a) + return show(io, unwrap(a)) + else + return printnode(io, a) + end +end +function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) + if !iscall(a) + @invoke show(io, mime, a::AbstractNamedDimsArray) + return nothing + else + show(io, a) + return nothing + end +end + function Base.:*(a::LazyNamedDimsArray) u = unwrap(a) - if u isa AbstractNamedDimsArray + if !iscall(u) return LazyNamedDimsArray(Mul([lazy(u)])) - elseif u isa Mul + elseif ismul(u) return a else return error("Variant not supported.") @@ -191,4 +257,67 @@ function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) return -a end +struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} + name::Name + axes::Axes + function SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) where {T} + N = length(ax) + return new{T, N, typeof(name), typeof(ax)}(name, ax) + end +end +function SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) + return SymbolicArray{Any}(name, ax) +end +function SymbolicArray{T}(name, ax::AbstractUnitRange...) where {T} + return SymbolicArray{T}(name, ax) +end +function SymbolicArray(name, ax::AbstractUnitRange...) + return SymbolicArray{Any}(name, ax) +end +symname(a::SymbolicArray) = getfield(a, :name) +Base.axes(a::SymbolicArray) = getfield(a, :axes) +Base.size(a::SymbolicArray) = length.(axes(a)) +function Base.:(==)(a::SymbolicArray, b::SymbolicArray) + return symname(a) == symname(b) && axes(a) == axes(b) +end +function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) + Base.summary(io, a) + println(io, ":") + print(io, repr(symname(a))) + return nothing +end +function Base.show(io::IO, a::SymbolicArray) + print(io, "SymbolicArray(", symname(a), ", ", size(a), ")") + return nothing +end +using AbstractTrees: AbstractTrees +function AbstractTrees.printnode(io::IO, a::SymbolicArray) + print(io, repr(symname(a))) + return nothing +end +const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = + NamedDimsArray{T, N, Parent, DimNames} +function symnameddims(name) + return lazy(NamedDimsArray(SymbolicArray(name), ())) +end +function printnode(io::IO, a::SymbolicNamedDimsArray) + print(io, symname(dename(a))) + if ndims(a) > 0 + print(io, "[", join(dimnames(a), ","), "]") + end + return nothing +end +function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return issetequal(inds(a), inds(b)) && dename(a) == dename(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) + return lazy(a) * lazy(b) +end +function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) + return lazy(a) * b +end +function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) + return a * lazy(b) +end + end diff --git a/test/Project.toml b/test/Project.toml index 9646508..5a5ce6a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" 
WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] +AbstractTrees = "0.4.5" Aqua = "0.8.14" Dictionaries = "0.4.5" Graphs = "1.13.1" diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 4c38c5e..40735ec 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,6 +1,8 @@ +using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArray, Mul, lazy -using NamedDimsArrays: NamedDimsArray, inds, nameddims +using ITensorNetworksNext.LazyNamedDimsArrays: + LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims +using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims using TermInterface: arguments, arity, @@ -33,6 +35,7 @@ using WrappedUnions: unwrap @test materialize(l) ≈ a1 * a2 * a3 @test issetequal(inds(l), symdiff(inds.((a1, a2, a3))...)) @test unwrap(l) isa Mul + @test ismul(unwrap(l)) @test unwrap(l).arguments == [l1 * l2, l3] # TermInterface.jl @test operation(unwrap(l)) ≡ * @@ -54,6 +57,15 @@ using WrappedUnions: unwrap @test_throws ErrorException operation(l1) @test_throws ErrorException sorted_arguments(l1) @test_throws ErrorException sorted_children(l1) + @test AbstractTrees.children(l1) ≡ () + @test AbstractTrees.nodevalue(l1) ≡ a1 + @test sprint(show, l1) == sprint(show, a1) + # TODO: Fix this test, it is basically correct but the type parameters + # print in a different way. + # @test sprint(show, MIME"text/plain"(), l1) == + # replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray") + @test sprint(printnode, l1) == "[:i, :j]" + @test sprint(print_tree, l1) == "[:i, :j]\n" l = l1 * l2 * l3 @test arguments(l) == [l1 * l2, l3] @@ -66,5 +78,32 @@ using WrappedUnions: unwrap @test operation(l) ≡ * @test sorted_arguments(l) == [l1 * l2, l3] @test sorted_children(l) == [l1 * l2, l3] + @test AbstractTrees.children(l) == [l1 * l2, l3] + @test AbstractTrees.nodevalue(l) ≡ * + @test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint(print_tree, l) == + "(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n" + end + + @testset "symnameddims" begin + a = symnameddims(:a) + b = symnameddims(:b) + c = symnameddims(:c) + @test a isa LazyNamedDimsArray + @test unwrap(a) isa NamedDimsArray + @test dename(a) isa SymbolicArray + @test dename(unwrap(a)) isa SymbolicArray + @test dename(unwrap(a)) == SymbolicArray(:a) + @test inds(a) == () + @test dimnames(a) == () + + ex = a * b * c + @test copy(ex) == ex + @test arguments(ex) == [a * b, c] + @test operation(ex) ≡ * + @test sprint(show, ex) == "((a * b) * c)" + @test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)" end end From adb429b67bce3db69e10c52effe5b6d02a086e26 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 8 Oct 2025 19:44:13 -0400 Subject: [PATCH 07/13] Define `substitute`, split off generic code (#12) --- Project.toml | 4 +- src/lazynameddimsarrays.jl | 357 +++++++++++++++++++------------ test/test_lazynameddimsarrays.jl | 83 ++++--- 3 files changed, 277 insertions(+), 167 deletions(-) diff --git a/Project.toml b/Project.toml index 5cfd81d..b944769 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = 
"302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.8" +version = "0.1.9" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -17,6 +17,7 @@ NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" +TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" [compat] @@ -33,5 +34,6 @@ NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" TermInterface = "2" +TypeParameterAccessors = "0.4.4" WrappedUnions = "0.3" julia = "1.10" diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index 4a038e8..e1b4b27 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -11,61 +11,55 @@ using NamedDimsArrays: dimnames, inds using TermInterface: TermInterface, arguments, iscall, maketerm, operation, sorted_arguments +using TypeParameterAccessors: unspecify_type_parameters -# Custom version of `AbstractTrees.printnode` to -# avoid type piracy when overloading on `AbstractNamedDimsArray`. -printnode(io::IO, x) = AbstractTrees.printnode(io, x) -function printnode(io::IO, a::AbstractNamedDimsArray) - show(io, collect(dimnames(a))) - return nothing -end +lazy(x) = error("Not defined.") -struct Mul{A} - arguments::Vector{A} +generic_map(f, v) = map(f, v) +generic_map(f, v::AbstractDict) = Dict(eachindex(v) .=> map(f, values(v))) +generic_map(f, v::AbstractSet) = Set([f(x) for x in v]) + +# Defined to avoid type piracy. +# TODO: Define a proper hash function +# in NamedDimsArrays.jl, maybe one that is +# independent of the order of dimensions. +function _hash(a::NamedDimsArray, h::UInt64) + h = hash(:NamedDimsArray, h) + h = hash(dename(a), h) + for i in inds(a) + h = hash(i, h) + end + return h end -TermInterface.arguments(m::Mul) = getfield(m, :arguments) -TermInterface.children(m::Mul) = arguments(m) -TermInterface.head(m::Mul) = operation(m) -TermInterface.iscall(m::Mul) = true -TermInterface.isexpr(m::Mul) = iscall(m) -TermInterface.maketerm(::Type{Mul}, head::typeof(*), args, metadata) = Mul(args) -TermInterface.operation(m::Mul) = * -TermInterface.sorted_arguments(m::Mul) = arguments(m) -TermInterface.sorted_children(m::Mul) = sorted_arguments(a) -ismul(x) = false -ismul(m::Mul) = true -function Base.show(io::IO, m::Mul) - args = map(arg -> sprint(printnode, arg), arguments(m)) - print(io, "(", join(args, " $(operation(m)) "), ")") - return nothing +function _hash(x, h::UInt64) + return hash(x, h) end -@wrapped struct LazyNamedDimsArray{ - T, A <: AbstractNamedDimsArray{T}, - } <: AbstractNamedDimsArray{T, Any} - union::Union{A, Mul{LazyNamedDimsArray{T, A}}} +# Custom version of `AbstractTrees.printnode` to +# avoid type piracy when overloading on `AbstractNamedDimsArray`. +printnode_nameddims(io::IO, x) = AbstractTrees.printnode(io, x) +function printnode_nameddims(io::IO, a::AbstractNamedDimsArray) + show(io, collect(dimnames(a))) + return nothing end -function NamedDimsArrays.inds(a::LazyNamedDimsArray) - u = unwrap(a) - if !iscall(u) - return inds(u) - elseif ismul(u) - return mapreduce(inds, symdiff, arguments(u)) +# Generic lazy functionality. 
+function maketerm_lazy(type::Type, head, args, metadata) + if head ≡ * + return type(maketerm(Mul, head, args, metadata)) else - return error("Variant not supported.") + return error("Only mul supported right now.") end end -function NamedDimsArrays.dename(a::LazyNamedDimsArray) +function getindex_lazy(a::AbstractArray, I...) u = unwrap(a) if !iscall(u) - return dename(u) + return u[I...] else - return error("Variant not supported.") + return error("Indexing into expression not supported.") end end - -function TermInterface.arguments(a::LazyNamedDimsArray) +function arguments_lazy(a) u = unwrap(a) if !iscall(u) return error("No arguments.") @@ -75,26 +69,19 @@ function TermInterface.arguments(a::LazyNamedDimsArray) return error("Variant not supported.") end end -function TermInterface.children(a::LazyNamedDimsArray) +function children_lazy(a) return arguments(a) end -function TermInterface.head(a::LazyNamedDimsArray) +function head_lazy(a) return operation(a) end -function TermInterface.iscall(a::LazyNamedDimsArray) +function iscall_lazy(a) return iscall(unwrap(a)) end -function TermInterface.isexpr(a::LazyNamedDimsArray) +function isexpr_lazy(a) return iscall(a) end -function TermInterface.maketerm(::Type{LazyNamedDimsArray}, head, args, metadata) - if head ≡ * - return LazyNamedDimsArray(maketerm(Mul, head, args, metadata)) - else - return error("Only mul supported right now.") - end -end -function TermInterface.operation(a::LazyNamedDimsArray) +function operation_lazy(a) u = unwrap(a) if !iscall(u) return error("No operation.") @@ -104,7 +91,7 @@ function TermInterface.operation(a::LazyNamedDimsArray) return error("Variant not supported.") end end -function TermInterface.sorted_arguments(a::LazyNamedDimsArray) +function sorted_arguments_lazy(a) u = unwrap(a) if !iscall(u) return error("No arguments.") @@ -114,28 +101,26 @@ function TermInterface.sorted_arguments(a::LazyNamedDimsArray) return error("Variant not supported.") end end -function TermInterface.sorted_children(a::LazyNamedDimsArray) +function sorted_children_lazy(a) return sorted_arguments(a) end -ismul(a::LazyNamedDimsArray) = ismul(unwrap(a)) - -function AbstractTrees.children(a::LazyNamedDimsArray) +ismul_lazy(a) = ismul(unwrap(a)) +function abstracttrees_children_lazy(a) if !iscall(a) return () else return arguments(a) end end -function AbstractTrees.nodevalue(a::LazyNamedDimsArray) +function nodevalue_lazy(a) if !iscall(a) return unwrap(a) else return operation(a) end end - using Base.Broadcast: materialize -function Base.Broadcast.materialize(a::LazyNamedDimsArray) +function materialize_lazy(a) u = unwrap(a) if !iscall(u) return u @@ -145,9 +130,8 @@ function Base.Broadcast.materialize(a::LazyNamedDimsArray) return error("Variant not supported.") end end -Base.copy(a::LazyNamedDimsArray) = materialize(a) - -function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) +copy_lazy(a) = materialize(a) +function equals_lazy(a1, a2) u1, u2 = unwrap.((a1, a2)) if !iscall(u1) && !iscall(u2) return u1 == u2 @@ -157,105 +141,210 @@ function Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) return false end end - -function printnode(io::IO, a::LazyNamedDimsArray) - return printnode(io, unwrap(a)) +function hash_lazy(a, h::UInt64) + h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) + # Use `_hash`, which defines a custom hash for NamedDimsArray. 
+ return _hash(unwrap(a), h) end -function AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) - return printnode(io, a) +function map_arguments_lazy(f, a) + u = unwrap(a) + if !iscall(u) + return error("No arguments to map.") + elseif ismul(u) + return lazy(map_arguments(f, u)) + else + return error("Variant not supported.") + end +end +function substitute_lazy(a, substitutions::AbstractDict) + haskey(substitutions, a) && return substitutions[a] + !iscall(a) && return a + return map_arguments(arg -> substitute(arg, substitutions), a) end -function Base.show(io::IO, a::LazyNamedDimsArray) +function substitute_lazy(a, substitutions) + return substitute(a, Dict(substitutions)) +end +function printnode_lazy(io, a) + # Use `printnode_nameddims` to avoid type piracy, + # since it overloads on `AbstractNamedDimsArray`. + return printnode_nameddims(io, unwrap(a)) +end +function show_lazy(io::IO, a) if !iscall(a) return show(io, unwrap(a)) else - return printnode(io, a) + return AbstractTrees.printnode(io, a) end end -function Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) +function show_lazy(io::IO, mime::MIME"text/plain", a) + summary(io, a) + println(io, ":") if !iscall(a) - @invoke show(io, mime, a::AbstractNamedDimsArray) + show(io, mime, unwrap(a)) return nothing else show(io, a) return nothing end end - -function Base.:*(a::LazyNamedDimsArray) +add_lazy(a1, a2) = error("Not implemented.") +sub_lazy(a) = error("Not implemented.") +sub_lazy(a1, a2) = error("Not implemented.") +function mul_lazy(a) u = unwrap(a) if !iscall(u) - return LazyNamedDimsArray(Mul([lazy(u)])) + return lazy(Mul([a])) elseif ismul(u) return a else return error("Variant not supported.") end end +# Note that this is nested by default. +mul_lazy(a1, a2) = lazy(Mul([a1, a2])) +mul_lazy(a1::Number, a2) = error("Not implemented.") +mul_lazy(a1, a2::Number) = error("Not implemented.") +mul_lazy(a1::Number, a2::Number) = a1 * a2 +div_lazy(a1, a2::Number) = error("Not implemented.") -function Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) - # Nested by default. - return LazyNamedDimsArray(Mul([a1, a2])) -end -function Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) - return error("Not implemented.") +# NamedDimsArrays.jl interface. +function inds_lazy(a) + u = unwrap(a) + if !iscall(u) + return inds(u) + elseif ismul(u) + return mapreduce(inds, symdiff, arguments(u)) + else + return error("Variant not supported.") + end end -function Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) - return error("Not implemented.") +function dename_lazy(a) + u = unwrap(a) + if !iscall(u) + return dename(u) + else + return error("Variant not supported.") + end end -function Base.:*(c::Number, a::LazyNamedDimsArray) - return error("Not implemented.") + +# Lazy broadcasting. +struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end +function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) + return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") end -function Base.:*(a::LazyNamedDimsArray, c::Number) - return error("Not implemented.") +# Linear operations. 
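+# These forward broadcasted scalar operations to the `add_lazy`, `sub_lazy`,
+# `mul_lazy` and `div_lazy` functions above. Most scalar variants currently
+# throw "Not implemented" errors, so for now the forwards mainly reserve the
+# syntax; for example `2 .* l` for a lazy `l` dispatches to `mul_lazy(2, l)`.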
+Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) = a1 + a2 +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) = a1 - a2 +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) = c * a +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) = a * c +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) = a * b +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) = a / c +Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) = -a + +# Generic functionality for Applied types, like `Mul`, `Add`, etc. +ismul(a) = operation(a) ≡ * +head_applied(a) = operation(a) +iscall_applied(a) = true +isexpr_applied(a) = iscall(a) +function show_applied(io::IO, a) + args = map(arg -> sprint(AbstractTrees.printnode, arg), arguments(a)) + print(io, "(", join(args, " $(operation(a)) "), ")") + return nothing end -function Base.:/(a::LazyNamedDimsArray, c::Number) - return error("Not implemented.") +sorted_arguments_applied(a) = arguments(a) +children_applied(a) = arguments(a) +sorted_children_applied(a) = sorted_arguments(a) +function maketerm_applied(type, head, args, metadata) + term = type(args) + @assert head ≡ operation(term) + return term +end +map_arguments_applied(f, a) = unspecify_type_parameters(typeof(a))(map(f, arguments(a))) +function hash_applied(a, h::UInt64) + h = hash(Symbol(unspecify_type_parameters(typeof(a))), h) + for arg in arguments(a) + h = hash(arg, h) + end + return h end -function Base.:-(a::LazyNamedDimsArray) - return error("Not implemented.") + +abstract type Applied end +TermInterface.head(a::Applied) = head_applied(a) +TermInterface.iscall(a::Applied) = iscall_applied(a) +TermInterface.isexpr(a::Applied) = isexpr_applied(a) +Base.show(io::IO, a::Applied) = show_applied(io, a) +TermInterface.sorted_arguments(a::Applied) = sorted_arguments_applied(a) +TermInterface.children(a::Applied) = children_applied(a) +TermInterface.sorted_children(a::Applied) = sorted_children_applied(a) +function TermInterface.maketerm(type::Type{<:Applied}, head, args, metadata) + return maketerm_applied(type, head, args, metadata) +end +map_arguments(f, a::Applied) = map_arguments_applied(f, a) +Base.hash(a::Applied, h::UInt64) = hash_applied(a, h) + +struct Mul{A} <: Applied + arguments::Vector{A} end +TermInterface.arguments(m::Mul) = getfield(m, :arguments) +TermInterface.operation(m::Mul) = * +@wrapped struct LazyNamedDimsArray{ + T, A <: AbstractNamedDimsArray{T}, + } <: AbstractNamedDimsArray{T, Any} + union::Union{A, Mul{LazyNamedDimsArray{T, A}}} +end function LazyNamedDimsArray(a::AbstractNamedDimsArray) return LazyNamedDimsArray{eltype(a), typeof(a)}(a) end function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} return LazyNamedDimsArray{T, A}(a) end -function lazy(a::AbstractNamedDimsArray) - return LazyNamedDimsArray(a) -end +lazy(a::LazyNamedDimsArray) = a +lazy(a::AbstractNamedDimsArray) = LazyNamedDimsArray(a) +lazy(a::Mul{<:LazyNamedDimsArray}) = LazyNamedDimsArray(a) + +NamedDimsArrays.inds(a::LazyNamedDimsArray) = inds_lazy(a) +NamedDimsArrays.dename(a::LazyNamedDimsArray) = dename_lazy(a) # Broadcasting -struct LazyNamedDimsArrayStyle <: AbstractNamedDimsArrayStyle{Any} end function Base.BroadcastStyle(::Type{<:LazyNamedDimsArray}) return LazyNamedDimsArrayStyle() end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, f, as...) 
- return error("Arbitrary broadcasting not supported for LazyNamedDimsArray.") -end -# Linear operations. -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(+), a1, a2) - return a1 + a2 -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a1, a2) - return a1 - a2 -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), c::Number, a) - return c * a -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a, c::Number) - return a * c -end -# Fix ambiguity error. -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(*), a::Number, b::Number) - return a * b -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(/), a, c::Number) - return a / c -end -function Broadcast.broadcasted(::LazyNamedDimsArrayStyle, ::typeof(-), a) - return -a -end + +# Derived functionality. +function TermInterface.maketerm(type::Type{LazyNamedDimsArray}, head, args, metadata) + return maketerm_lazy(type, head, args, metadata) +end +Base.getindex(a::LazyNamedDimsArray, I::Int...) = getindex_lazy(a, I...) +TermInterface.arguments(a::LazyNamedDimsArray) = arguments_lazy(a) +TermInterface.children(a::LazyNamedDimsArray) = children_lazy(a) +TermInterface.head(a::LazyNamedDimsArray) = head_lazy(a) +TermInterface.iscall(a::LazyNamedDimsArray) = iscall_lazy(a) +TermInterface.isexpr(a::LazyNamedDimsArray) = isexpr_lazy(a) +TermInterface.operation(a::LazyNamedDimsArray) = operation_lazy(a) +TermInterface.sorted_arguments(a::LazyNamedDimsArray) = sorted_arguments_lazy(a) +AbstractTrees.children(a::LazyNamedDimsArray) = abstracttrees_children_lazy(a) +TermInterface.sorted_children(a::LazyNamedDimsArray) = sorted_children_lazy(a) +ismul(a::LazyNamedDimsArray) = ismul_lazy(a) +AbstractTrees.nodevalue(a::LazyNamedDimsArray) = nodevalue_lazy(a) +Base.Broadcast.materialize(a::LazyNamedDimsArray) = materialize_lazy(a) +Base.copy(a::LazyNamedDimsArray) = copy_lazy(a) +Base.:(==)(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = equals_lazy(a1, a2) +Base.hash(a::LazyNamedDimsArray, h::UInt64) = hash_lazy(a, h) +map_arguments(f, a::LazyNamedDimsArray) = map_arguments_lazy(f, a) +substitute(a::LazyNamedDimsArray, substitutions) = substitute_lazy(a, substitutions) +AbstractTrees.printnode(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) +printnode_nameddims(io::IO, a::LazyNamedDimsArray) = printnode_lazy(io, a) +Base.show(io::IO, a::LazyNamedDimsArray) = show_lazy(io, a) +Base.show(io::IO, mime::MIME"text/plain", a::LazyNamedDimsArray) = show_lazy(io, mime, a) +Base.:*(a::LazyNamedDimsArray) = mul_lazy(a) +Base.:*(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) +Base.:+(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = add_lazy(a1, a2) +Base.:-(a1::LazyNamedDimsArray, a2::LazyNamedDimsArray) = sub_lazy(a1, a2) +Base.:*(a1::Number, a2::LazyNamedDimsArray) = mul_lazy(a1, a2) +Base.:*(a1::LazyNamedDimsArray, a2::Number) = mul_lazy(a1, a2) +Base.:/(a1::LazyNamedDimsArray, a2::Number) = div_lazy(a1, a2) +Base.:-(a::LazyNamedDimsArray) = sub_lazy(a) struct SymbolicArray{T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}} <: AbstractArray{T, N} name::Name @@ -280,6 +369,17 @@ Base.size(a::SymbolicArray) = length.(axes(a)) function Base.:(==)(a::SymbolicArray, b::SymbolicArray) return symname(a) == symname(b) && axes(a) == axes(b) end +function Base.hash(a::SymbolicArray, h::UInt64) + h = hash(:SymbolicArray, h) + h = hash(symname(a), h) + return hash(size(a), h) +end +function 
Base.getindex(a::SymbolicArray{<:Any, N}, I::Vararg{Int, N}) where {N} + return error("Indexing into SymbolicArray not supported.") +end +function Base.setindex!(a::SymbolicArray{<:Any, N}, value, I::Vararg{Int, N}) where {N} + return error("Indexing into SymbolicArray not supported.") +end function Base.show(io::IO, mime::MIME"text/plain", a::SymbolicArray) Base.summary(io, a) println(io, ":") @@ -300,24 +400,19 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} = function symnameddims(name) return lazy(NamedDimsArray(SymbolicArray(name), ())) end -function printnode(io::IO, a::SymbolicNamedDimsArray) +function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray) print(io, symname(dename(a))) if ndims(a) > 0 print(io, "[", join(dimnames(a), ","), "]") end return nothing end +printnode_nameddims(io::IO, a::SymbolicNamedDimsArray) = AbstractTrees.printnode(io, a) function Base.:(==)(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) return issetequal(inds(a), inds(b)) && dename(a) == dename(b) end -function Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) - return lazy(a) * lazy(b) -end -function Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) - return lazy(a) * b -end -function Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) - return a * lazy(b) -end +Base.:*(a::SymbolicNamedDimsArray, b::SymbolicNamedDimsArray) = lazy(a) * lazy(b) +Base.:*(a::SymbolicNamedDimsArray, b::LazyNamedDimsArray) = lazy(a) * b +Base.:*(a::LazyNamedDimsArray, b::SymbolicNamedDimsArray) = a * lazy(b) end diff --git a/test/test_lazynameddimsarrays.jl b/test/test_lazynameddimsarrays.jl index 40735ec..cc86fdc 100644 --- a/test/test_lazynameddimsarrays.jl +++ b/test/test_lazynameddimsarrays.jl @@ -1,27 +1,24 @@ using AbstractTrees: AbstractTrees, print_tree, printnode using Base.Broadcast: materialize -using ITensorNetworksNext.LazyNamedDimsArrays: - LazyNamedDimsArray, Mul, SymbolicArray, ismul, lazy, symnameddims -using NamedDimsArrays: NamedDimsArray, dename, dimnames, inds, nameddims -using TermInterface: - arguments, - arity, - children, - head, - iscall, - isexpr, - maketerm, - operation, - sorted_arguments, - sorted_children +using ITensorNetworksNext.LazyNamedDimsArrays: LazyNamedDimsArrays, LazyNamedDimsArray, + Mul, SymbolicArray, ismul, lazy, substitute, symnameddims +using NamedDimsArrays: NamedDimsArray, @names, dename, dimnames, inds, nameddims, namedoneto +using TermInterface: arguments, arity, children, head, iscall, isexpr, maketerm, operation, + sorted_arguments, sorted_children using Test: @test, @test_throws, @testset using WrappedUnions: unwrap @testset "LazyNamedDimsArrays" begin + function sprint_namespaced(x) + context = (:module => LazyNamedDimsArrays) + module_prefix = "ITensorNetworksNext.LazyNamedDimsArrays." + return replace(sprint(show, MIME"text/plain"(), x; context), module_prefix => "") + end @testset "Basics" begin - a1 = nameddims(randn(2, 2), (:i, :j)) - a2 = nameddims(randn(2, 2), (:j, :k)) - a3 = nameddims(randn(2, 2), (:k, :l)) + i, j, k, l = namedoneto.(2, (:i, :j, :k, :l)) + a1 = randn(i, j) + a2 = randn(j, k) + a3 = randn(k, l) l1, l2, l3 = lazy.((a1, a2, a3)) for li in (l1, l2, l3) @test li isa LazyNamedDimsArray @@ -62,8 +59,8 @@ using WrappedUnions: unwrap @test sprint(show, l1) == sprint(show, a1) # TODO: Fix this test, it is basically correct but the type parameters # print in a different way. 
- # @test sprint(show, MIME"text/plain"(), l1) == - # replace(sprint(show, MIME"text/plain"(), a1), "NamedDimsArray" => "LazyNamedDimsArray") + # @test sprint_namespaced(l1) == + # replace(sprint_namespaced(a1), "NamedDimsArray" => "LazyNamedDimsArray") @test sprint(printnode, l1) == "[:i, :j]" @test sprint(print_tree, l1) == "[:i, :j]\n" @@ -81,29 +78,45 @@ using WrappedUnions: unwrap @test AbstractTrees.children(l) == [l1 * l2, l3] @test AbstractTrees.nodevalue(l) ≡ * @test sprint(show, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" - @test sprint(show, MIME"text/plain"(), l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" + @test sprint_namespaced(l) == + "named(Base.OneTo(2), :i)×named(Base.OneTo(2), :l) " * + "LazyNamedDimsArray{Float64, …}:\n(([:i, :j] * [:j, :k]) * [:k, :l])" @test sprint(printnode, l) == "(([:i, :j] * [:j, :k]) * [:k, :l])" @test sprint(print_tree, l) == - "(([:i, :j] * [:j, :k]) * [:k, :l])\n├─ ([:i, :j] * [:j, :k])\n│ ├─ [:i, :j]\n│ └─ [:j, :k]\n└─ [:k, :l]\n" + "(([:i, :j] * [:j, :k]) * [:k, :l])\n" * + "├─ ([:i, :j] * [:j, :k])\n" * + "│ ├─ [:i, :j]\n│ └─ [:j, :k]\n" * + "└─ [:k, :l]\n" end @testset "symnameddims" begin - a = symnameddims(:a) - b = symnameddims(:b) - c = symnameddims(:c) - @test a isa LazyNamedDimsArray - @test unwrap(a) isa NamedDimsArray - @test dename(a) isa SymbolicArray - @test dename(unwrap(a)) isa SymbolicArray - @test dename(unwrap(a)) == SymbolicArray(:a) - @test inds(a) == () - @test dimnames(a) == () + a1, a2, a3 = symnameddims.((:a1, :a2, :a3)) + @test a1 isa LazyNamedDimsArray + @test unwrap(a1) isa NamedDimsArray + @test dename(a1) isa SymbolicArray + @test dename(unwrap(a1)) isa SymbolicArray + @test dename(unwrap(a1)) == SymbolicArray(:a1) + @test inds(a1) == () + @test dimnames(a1) == () - ex = a * b * c + ex = a1 * a2 * a3 @test copy(ex) == ex - @test arguments(ex) == [a * b, c] + @test arguments(ex) == [a1 * a2, a3] @test operation(ex) ≡ * - @test sprint(show, ex) == "((a * b) * c)" - @test sprint(show, MIME"text/plain"(), ex) == "((a * b) * c)" + @test sprint(show, ex) == "((a1 * a2) * a3)" + @test sprint_namespaced(ex) == + "0-dimensional LazyNamedDimsArray{Any, …}:\n((a1 * a2) * a3)" + end + + @testset "substitute" begin + s = symnameddims.((:a1, :a2, :a3)) + i = @names i[1:4] + a = (randn(2, 2)[i[1], i[2]], randn(2, 2)[i[2], i[3]], randn(2, 2)[i[3], i[4]]) + l = lazy.(a) + + seq = s[1] * (s[2] * s[3]) + net = substitute(seq, s .=> l) + @test net == l[1] * (l[2] * l[3]) + @test arguments(net) == [l[1], l[2] * l[3]] end end From 281911608bf325f9305a803301ec5092507f2243 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Thu, 23 Oct 2025 19:00:52 -0400 Subject: [PATCH 08/13] Fix lazy ITensor (#14) --- Project.toml | 2 +- src/lazynameddimsarrays.jl | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index b944769..72516e4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.9" +version = "0.1.10" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/lazynameddimsarrays.jl b/src/lazynameddimsarrays.jl index e1b4b27..23e0679 100644 --- a/src/lazynameddimsarrays.jl +++ b/src/lazynameddimsarrays.jl @@ -294,7 +294,9 @@ TermInterface.operation(m::Mul) = * union::Union{A, Mul{LazyNamedDimsArray{T, A}}} end function LazyNamedDimsArray(a::AbstractNamedDimsArray) - return LazyNamedDimsArray{eltype(a), typeof(a)}(a) + # Use 
`eltype(typeof(a))` for arrays that have different + # runtime and compile time eltypes, like `ITensor`. + return LazyNamedDimsArray{eltype(typeof(a)), typeof(a)}(a) end function LazyNamedDimsArray(a::Mul{LazyNamedDimsArray{T, A}}) where {T, A} return LazyNamedDimsArray{T, A}(a) From c4085f7289b66e1ef551e586dbd2fdb125fda3cd Mon Sep 17 00:00:00 2001 From: Joseph Tindall <51231103+JoeyT1994@users.noreply.github.com> Date: Sun, 26 Oct 2025 11:36:35 -0400 Subject: [PATCH 09/13] Working `contract_network` (#7) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matt Fishman --- Project.toml | 9 +++- .../ITensorNetworksNextTensorOperationsExt.jl | 16 +++++++ src/ITensorNetworksNext.jl | 1 + src/contract_network.jl | 47 +++++++++++++++++++ test/Project.toml | 2 + test/test_contract_network.jl | 39 +++++++++++++++ 6 files changed, 113 insertions(+), 1 deletion(-) create mode 100644 ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl create mode 100644 src/contract_network.jl create mode 100644 test/test_contract_network.jl diff --git a/Project.toml b/Project.toml index 72516e4..42df730 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworksNext" uuid = "302f2e75-49f0-4526-aef7-d8ba550cb06c" authors = ["ITensor developers and contributors"] -version = "0.1.10" +version = "0.1.11" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -20,6 +20,12 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" +[weakdeps] +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" + +[extensions] +ITensorNetworksNextTensorOperationsExt = "TensorOperations" + [compat] AbstractTrees = "0.4.5" Adapt = "4.3" @@ -33,6 +39,7 @@ NamedDimsArrays = "0.8" NamedGraphs = "0.6.9, 0.7" SimpleTraits = "0.9.5" SplitApplyCombine = "1.2.3" +TensorOperations = "5.3.1" TermInterface = "2" TypeParameterAccessors = "0.4.4" WrappedUnions = "0.3" diff --git a/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl new file mode 100644 index 0000000..f3b90bf --- /dev/null +++ b/ext/ITensorNetworksNextTensorOperationsExt/ITensorNetworksNextTensorOperationsExt.jl @@ -0,0 +1,16 @@ +module ITensorNetworksNextTensorOperationsExt + +using BackendSelection: @Algorithm_str, Algorithm +using NamedDimsArrays: inds +using ITensorNetworksNext: ITensorNetworksNext, contraction_sequence_to_expr +using TensorOperations: TensorOperations, optimaltree + +function ITensorNetworksNext.contraction_sequence(::Algorithm"optimal", tn::Vector{<:AbstractArray}) + network = collect.(inds.(tn)) + #Converting dims to Float64 to minimize overflow issues + inds_to_dims = Dict(i => Float64(length(i)) for i in unique(reduce(vcat, network))) + seq, _ = optimaltree(network, inds_to_dims) + return contraction_sequence_to_expr(seq) +end + +end diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 35c9e59..b59c3bd 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -3,5 +3,6 @@ module ITensorNetworksNext include("lazynameddimsarrays.jl") include("abstracttensornetwork.jl") include("tensornetwork.jl") +include("contract_network.jl") end diff --git a/src/contract_network.jl b/src/contract_network.jl new file mode 100644 index 0000000..67d69e0 --- /dev/null +++ 
b/src/contract_network.jl @@ -0,0 +1,47 @@ +using BackendSelection: @Algorithm_str, Algorithm +using ITensorNetworksNext.LazyNamedDimsArrays: substitute, materialize, lazy, + symnameddims + +#Algorithmic defaults +default_sequence_alg(::Algorithm"exact") = "leftassociative" +default_sequence(::Algorithm"exact") = nothing +function set_default_kwargs(alg::Algorithm"exact") + sequence = get(alg, :sequence, nothing) + sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) + return Algorithm("exact"; sequence, sequence_alg) +end + +function contraction_sequence_to_expr(seq) + if seq isa AbstractVector + return prod(contraction_sequence_to_expr, seq) + else + return symnameddims(seq) + end +end + +function contraction_sequence(::Algorithm"leftassociative", tn::Vector{<:AbstractArray}) + return prod(symnameddims, 1:length(tn)) +end + +function contraction_sequence(tn::Vector{<:AbstractArray}; sequence_alg = default_sequence_alg(Algorithm("exact"))) + return contraction_sequence(Algorithm(sequence_alg), tn) +end + +function contract_network(alg::Algorithm"exact", tn::Vector{<:AbstractArray}) + if !isnothing(alg.sequence) + sequence = alg.sequence + else + sequence = contraction_sequence(tn; sequence_alg = alg.sequence_alg) + end + + sequence = substitute(sequence, Dict(symnameddims(i) => lazy(tn[i]) for i in 1:length(tn))) + return materialize(sequence) +end + +function contract_network(alg::Algorithm"exact", tn::AbstractTensorNetwork) + return contract_network(alg, [tn[v] for v in vertices(tn)]) +end + +function contract_network(tn; alg, kwargs...) + return contract_network(set_default_kwargs(Algorithm(alg; kwargs...)), tn) +end diff --git a/test/Project.toml b/test/Project.toml index 5a5ce6a..4c12286 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,7 @@ NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" +TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" WrappedUnions = "325db55a-9c6c-5b90-b1a2-ec87e7a38c44" @@ -25,5 +26,6 @@ NamedGraphs = "0.6.8, 0.7" SafeTestsets = "0.1" Suppressor = "0.2.8" TermInterface = "2" +TensorOperations = "5.3.1" Test = "1.10" WrappedUnions = "0.3" diff --git a/test/test_contract_network.jl b/test/test_contract_network.jl new file mode 100644 index 0000000..2b7b945 --- /dev/null +++ b/test/test_contract_network.jl @@ -0,0 +1,39 @@ +using Graphs: edges +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using NamedGraphs.NamedGraphGenerators: named_grid +using ITensorBase: Index, ITensor +using ITensorNetworksNext: + TensorNetwork, linkinds, siteinds, contract_network +using TensorOperations: TensorOperations +using Test: @test, @testset + +@testset "contract_network" begin + @testset "Contract Vectors of ITensors" begin + i, j, k = Index(2), Index(2), Index(5) + A = ITensor([1.0 1.0; 0.5 1.0], i, j) + B = ITensor([2.0, 1.0], i) + C = ITensor([5.0, 1.0], j) + D = ITensor([-2.0, 3.0, 4.0, 5.0, 1.0], k) + + ABCD_1 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "leftassociative") + ABCD_2 = contract_network([A, B, C, D]; alg = "exact", sequence_alg = "optimal") + + @test ABCD_1 == ABCD_2 + end + + @testset "Contract One Dimensional Network" begin + dims = (4, 4) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, 
Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + z1 = contract_network(tn; alg = "exact", sequence_alg = "optimal")[] + z2 = contract_network(tn; alg = "exact", sequence_alg = "leftassociative")[] + + @test abs(z1 - z2) / abs(z1) <= 1.0e3 * eps(Float64) + end +end From 7c9800e5fc89d4ca40929f2c35de64e11ffb0faf Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 2 Oct 2025 13:25:12 -0400 Subject: [PATCH 10/13] Working BP Commit --- src/ITensorNetworksNext.jl | 3 +++ src/abstracttensornetwork.jl | 2 +- test/test_beliefpropagation.jl | 25 +++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 test/test_beliefpropagation.jl diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index b59c3bd..62471ad 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -5,4 +5,7 @@ include("abstracttensornetwork.jl") include("tensornetwork.jl") include("contract_network.jl") +include("beliefpropagation/abstractbeliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationcache.jl") + end diff --git a/src/abstracttensornetwork.jl b/src/abstracttensornetwork.jl index cdcf409..3cd2533 100644 --- a/src/abstracttensornetwork.jl +++ b/src/abstracttensornetwork.jl @@ -275,4 +275,4 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork) return nothing end -Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) +Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph) \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl new file mode 100644 index 0000000..4b179fb --- /dev/null +++ b/test/test_beliefpropagation.jl @@ -0,0 +1,25 @@ +using Dictionaries: Dictionary +using ITensorBase: Index +using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, + partitionfunction +using Graphs: edges, vertices +using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges +using Test: @test, @testset + +@testset "BeliefPropagation" begin + dims = (4, 1) + g = named_grid(dims) + l = Dict(e => Index(2) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 10) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 +end \ No newline at end of file From 7fd8cd9b3b00653dd000439dca9ac1c884bccf6b Mon Sep 17 00:00:00 2001 From: Joey Date: Thu, 23 Oct 2025 18:23:27 -0400 Subject: [PATCH 11/13] BP Code --- .../abstractbeliefpropagationcache.jl | 151 +++++++++++ .../beliefpropagationcache.jl | 237 ++++++++++++++++++ test/test_beliefpropagation.jl | 20 +- 3 files changed, 407 insertions(+), 1 deletion(-) create mode 100644 src/beliefpropagation/abstractbeliefpropagationcache.jl create mode 100644 src/beliefpropagation/beliefpropagationcache.jl diff --git a/src/beliefpropagation/abstractbeliefpropagationcache.jl b/src/beliefpropagation/abstractbeliefpropagationcache.jl new file mode 100644 index 0000000..5eae283 --- /dev/null +++ 
b/src/beliefpropagation/abstractbeliefpropagationcache.jl
@@ -0,0 +1,151 @@
+abstract type AbstractBeliefPropagationCache{V} <: AbstractGraph{V} end
+
+#Interface
+factor(bp_cache::AbstractBeliefPropagationCache, vertex) = not_implemented()
+setfactor!(bp_cache::AbstractBeliefPropagationCache, vertex, factor) = not_implemented()
+messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
+message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge) = not_implemented()
+function default_message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
+    return not_implemented()
+end
+default_messages(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
+function setmessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge, message)
+    return not_implemented()
+end
+function deletemessage!(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
+    return not_implemented()
+end
+function rescale_messages(
+        bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}; kwargs...
+    )
+    return not_implemented()
+end
+function rescale_vertices(
+        bp_cache::AbstractBeliefPropagationCache, vertices::Vector; kwargs...
+    )
+    return not_implemented()
+end
+
+function vertex_scalar(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...)
+    return not_implemented()
+end
+function edge_scalar(
+        bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...
+    )
+    return not_implemented()
+end
+
+#Graph functionality needed
+Graphs.vertices(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
+Graphs.edges(bp_cache::AbstractBeliefPropagationCache) = not_implemented()
+function NamedGraphs.GraphsExtensions.boundary_edges(
+        bp_cache::AbstractBeliefPropagationCache, vertices; kwargs...
+    )
+    return not_implemented()
+end
+
+#Functions derived from the interface
+function setmessages!(bp_cache::AbstractBeliefPropagationCache, edges, messages)
+    for (e, m) in zip(edges, messages)
+        setmessage!(bp_cache, e, m)
+    end
+    return bp_cache
+end
+
+function deletemessages!(
+        bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge} = edges(bp_cache)
+    )
+    for e in edges
+        deletemessage!(bp_cache, e)
+    end
+    return bp_cache
+end
+
+function vertex_scalars(
+        bp_cache::AbstractBeliefPropagationCache, vertices = Graphs.vertices(bp_cache); kwargs...
+    )
+    return map(v -> region_scalar(bp_cache, v; kwargs...), vertices)
+end
+
+function edge_scalars(
+        bp_cache::AbstractBeliefPropagationCache, edges = Graphs.edges(bp_cache); kwargs...
+    )
+    return map(e -> region_scalar(bp_cache, e; kwargs...), edges)
+end
+
+function scalar_factors_quotient(bp_cache::AbstractBeliefPropagationCache)
+    return vertex_scalars(bp_cache), edge_scalars(bp_cache)
+end
+
+function incoming_messages(
+        bp_cache::AbstractBeliefPropagationCache, vertices::Vector{<:Any}; ignore_edges = []
+    )
+    b_edges = NamedGraphs.GraphsExtensions.boundary_edges(bp_cache, vertices; dir = :in)
+    b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges
+    return messages(bp_cache, b_edges)
+end
+
+function incoming_messages(bp_cache::AbstractBeliefPropagationCache, vertex; kwargs...)
+    return incoming_messages(bp_cache, [vertex]; kwargs...)
+end
+
+#Adapt interface for changing device
+function map_messages(f, bp_cache::AbstractBeliefPropagationCache, es = edges(bp_cache))
+    bp_cache = copy(bp_cache)
+    for e in es
+        setmessage!(bp_cache, e, f(message(bp_cache, e)))
+    end
+    return bp_cache
+end
+function map_factors(f, bp_cache::AbstractBeliefPropagationCache, vs = vertices(bp_cache))
+    bp_cache = copy(bp_cache)
+    for v in vs
+        setfactor!(bp_cache, v, f(factor(bp_cache, v)))
+    end
+    return bp_cache
+end
+function adapt_messages(to, bp_cache::AbstractBeliefPropagationCache, args...)
+    return map_messages(adapt(to), bp_cache, args...)
+end
+function adapt_factors(to, bp_cache::AbstractBeliefPropagationCache, args...)
+    return map_factors(adapt(to), bp_cache, args...)
+end
+
+function freenergy(bp_cache::AbstractBeliefPropagationCache)
+    numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache)
+    if any(t -> real(t) < 0, numerator_terms)
+        numerator_terms = complex.(numerator_terms)
+    end
+    if any(t -> real(t) < 0, denominator_terms)
+        denominator_terms = complex.(denominator_terms)
+    end
+
+    any(iszero, denominator_terms) && return -Inf
+    return sum(log.(numerator_terms)) - sum(log.(denominator_terms))
+end
+
+function partitionfunction(bp_cache::AbstractBeliefPropagationCache)
+    return exp(freenergy(bp_cache))
+end
+
+function rescale_messages(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge)
+    return rescale_messages(bp_cache, [edge])
+end
+
+function rescale_messages(bp_cache::AbstractBeliefPropagationCache)
+    return rescale_messages(bp_cache, edges(bp_cache))
+end
+
+function rescale_vertices(bpc::AbstractBeliefPropagationCache; kwargs...)
+    return rescale_vertices(bpc, collect(vertices(bpc)); kwargs...)
+end
+
+function rescale_vertex(bpc::AbstractBeliefPropagationCache, vertex; kwargs...)
+    return rescale_vertices(bpc, [vertex]; kwargs...)
+end
+
+function rescale(bpc::AbstractBeliefPropagationCache, args...; kwargs...)
+    bpc = rescale_messages(bpc)
+    bpc = rescale_vertices(bpc, args...; kwargs...)
+    return bpc
+end
diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl
new file mode 100644
index 0000000..295502a
--- /dev/null
+++ b/src/beliefpropagation/beliefpropagationcache.jl
@@ -0,0 +1,237 @@
+using DiagonalArrays: delta
+using Dictionaries: Dictionary, set!, delete!
+using LinearAlgebra: LinearAlgebra
+using Graphs: AbstractGraph, is_tree, connected_components +using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges +using ITensorBase: ITensor, dim +using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype + +struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: + AbstractBeliefPropagationCache{V} + network::N + messages::Dictionary +end + +messages(bp_cache::BeliefPropagationCache) = bp_cache.messages +network(bp_cache::BeliefPropagationCache) = bp_cache.network +default_messages() = Dictionary() + +BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) + +function Base.copy(bp_cache::BeliefPropagationCache) + return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) +end + +function deletemessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge) + ms = messages(bp_cache) + delete!(ms, e) + return bp_cache +end + +function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) + ms = messages(bp_cache) + set!(ms, e, message) + return bp_cache +end + +function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) + ms = messages(bp_cache) + return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) +end + +function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) + return [message(bp_cache, e) for e in edges] +end + +default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing +#Forward onto the network +for f in [ + :(Graphs.vertices), + :(Graphs.edges), + :(Graphs.is_tree), + :(NamedGraphs.GraphsExtensions.boundary_edges), + :(factors), + :(default_bp_maxiter), + :(ITensorNetworksNext.setfactor!), + :(ITensorNetworksNext.linkinds), + :(ITensorNetworksNext.underlying_graph), + ] + @eval begin + function $f(bp_cache::BeliefPropagationCache, args...; kwargs...) + return $f(network(bp_cache), args...; kwargs...) + end + end +end + +#TODO: Get subgraph working on an ITensorNetwork to overload this directly +function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) + return forest_cover_edge_sequence(underlying_graph(bp_cache)) +end + +function factors(tn::AbstractTensorNetwork, vertex) + return [tn[vertex]] +end + +function region_scalar(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[] +end + +function region_scalar(bp_cache::BeliefPropagationCache, vertex) + incoming_ms = incoming_messages(bp_cache, vertex) + state = factors(bp_cache, vertex) + return (reduce(*, incoming_ms) * reduce(*, state))[] +end + +function default_message(bp_cache::BeliefPropagationCache, edge::AbstractEdge) + return default_message(network(bp_cache), edge::AbstractEdge) +end + +function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) + t = ITensor(ones(dim.(linkinds(tn, edge))...), linkinds(tn, edge)...) + #TODO: Get datatype working on tensornetworks so we can support GPU, etc... 
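+    # The default message is an all-ones tensor on the link indices of `edge`,
+    # i.e. a flat, uninformative initial message.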
+ return t +end + +#Algorithmic defaults +default_update_alg(bp_cache::BeliefPropagationCache) = "bp" +default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" +default_normalize(::Algorithm"contract") = true +default_sequence_alg(::Algorithm"contract") = "optimal" +function set_default_kwargs(alg::Algorithm"contract") + normalize = get(alg, :normalize, default_normalize(alg)) + sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) + return Algorithm("contract"; normalize, sequence_alg) +end +function set_default_kwargs(alg::Algorithm"adapt_update") + _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) + return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) +end +default_verbose(::Algorithm"bp") = false +default_tol(::Algorithm"bp") = nothing +function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) + verbose = get(alg, :verbose, default_verbose(alg)) + maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) + edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) + tol = get(alg, :tol, default_tol(alg)) + message_update_alg = set_default_kwargs( + get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) + ) + return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) +end + +#TODO: Update message etc should go here... +function updated_message( + alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + vertex = src(edge) + incoming_ms = incoming_messages( + bp_cache, vertex; ignore_edges = typeof(edge)[reverse(edge)] + ) + state = factors(bp_cache, vertex) + #contract_list = ITensor[incoming_ms; state] + #sequence = contraction_sequence(contract_list; alg=alg.kwargs.sequence_alg) + #updated_messages = contract(contract_list; sequence) + updated_message = + !isempty(incoming_ms) ? reduce(*, state) * reduce(*, incoming_ms) : reduce(*, state) + if alg.normalize + message_norm = LinearAlgebra.norm(updated_message) + if !iszero(message_norm) + updated_message /= message_norm + end + end + return updated_message +end + +function updated_message( + bp_cache::BeliefPropagationCache, + edge::AbstractEdge; + alg = default_message_update_alg(bpc), + kwargs..., + ) + return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge) +end + +function update_message!( + message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge + ) + return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge)) +end + +""" +Do a sequential update of the message tensors on `edges` +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edges::Vector; + (update_diff!) = nothing, + ) + bpc = copy(bpc) + for e in edges + prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing + update_message!(alg.message_update_alg, bpc, e) + if !isnothing(update_diff!) + update_diff![] += message_diff(message(bpc, e), prev_message) + end + end + return bpc +end + +""" +Do parallel updates between groups of edges of all message tensors +Currently we send the full message tensor data struct to update for each edge_group. But really we only need the +mts relevant to that group. +""" +function update_iteration( + alg::Algorithm"bp", + bpc::AbstractBeliefPropagationCache, + edge_groups::Vector{<:Vector{<:AbstractEdge}}; + (update_diff!) 
= nothing, + ) + new_mts = empty(messages(bpc)) + for edges in edge_groups + bpc_t = update_iteration(alg.kwargs.message_update_alg, bpc, edges; (update_diff!)) + for e in edges + set!(new_mts, e, message(bpc_t, e)) + end + end + return set_messages(bpc, new_mts) +end + +""" +More generic interface for update, with default params +""" +function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache) + compute_error = !isnothing(alg.tol) + if isnothing(alg.maxiter) + error("You need to specify a number of iterations for BP!") + end + for i in 1:alg.maxiter + diff = compute_error ? Ref(0.0) : nothing + bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff) + if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol + if alg.verbose + println("BP converged to desired precision after $i iterations.") + end + break + end + end + return bpc +end + +function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...) + return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) +end + +#Edge sequence stuff +function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root_vertex) + forests = forest_cover(g) + edges = edgetype(g)[] + for forest in forests + trees = [forest[vs] for vs in connected_components(forest)] + for tree in trees + tree_edges = post_order_dfs_edges(tree, root_vertex(tree)) + push!(edges, vcat(tree_edges, reverse(reverse.(tree_edges)))...) + end + end + return edges +end \ No newline at end of file diff --git a/test/test_beliefpropagation.jl b/test/test_beliefpropagation.jl index 4b179fb..81ee722 100644 --- a/test/test_beliefpropagation.jl +++ b/test/test_beliefpropagation.jl @@ -3,11 +3,13 @@ using ITensorBase: Index using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork, adapt_messages, default_message, default_messages, edge_scalars, messages, setmessages!, factors, freenergy, partitionfunction using Graphs: edges, vertices -using NamedGraphs.NamedGraphGenerators: named_grid +using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree using NamedGraphs.GraphsExtensions: arranged_edges, incident_edges using Test: @test, @testset @testset "BeliefPropagation" begin + + #Chain of tensors dims = (4, 1) g = named_grid(dims) l = Dict(e => Index(2) for e in edges(g)) @@ -17,6 +19,22 @@ using Test: @test, @testset return randn(Tuple(is)) end + bpc = BeliefPropagationCache(tn) + bpc = ITensorNetworksNext.update(bpc; maxiter = 1) + z_bp = partitionfunction(bpc) + z_exact = reduce(*, [tn[v] for v in vertices(g)])[] + @test abs(z_bp - z_exact) <= 1e-14 + + #Tree of tensors + dims = (4, 3) + g = named_comb_tree(dims) + l = Dict(e => Index(3) for e in edges(g)) + l = merge(l, Dict(reverse(e) => l[e] for e in edges(g))) + tn = TensorNetwork(g) do v + is = map(e -> l[e], incident_edges(g, v)) + return randn(Tuple(is)) + end + bpc = BeliefPropagationCache(tn) bpc = ITensorNetworksNext.update(bpc; maxiter = 10) z_bp = partitionfunction(bpc) From 0bbf58450efc0033daaa477290ba539e016e1983 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Oct 2025 15:12:37 -0400 Subject: [PATCH 12/13] Reintroduce `AbstractNetworkIterator` and `AbstractProblem` interface --- src/ITensorNetworksNext.jl | 2 + src/abstract_problem.jl | 1 + src/iterators.jl | 173 +++++++++++++++++++++++++++++++++++++ test/test_iterators.jl | 161 ++++++++++++++++++++++++++++++++++ 4 files changed, 337 insertions(+) create mode 100644 src/abstract_problem.jl create mode 100644 src/iterators.jl create 
mode 100644 test/test_iterators.jl

diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl
index 62471ad..4ed09c0 100644
--- a/src/ITensorNetworksNext.jl
+++ b/src/ITensorNetworksNext.jl
@@ -4,6 +4,8 @@ include("lazynameddimsarrays.jl")
 include("abstracttensornetwork.jl")
 include("tensornetwork.jl")
 include("contract_network.jl")
+include("abstract_problem.jl")
+include("iterators.jl")
 
 include("beliefpropagation/abstractbeliefpropagationcache.jl")
 include("beliefpropagation/beliefpropagationcache.jl")
diff --git a/src/abstract_problem.jl b/src/abstract_problem.jl
new file mode 100644
index 0000000..5a65e0a
--- /dev/null
+++ b/src/abstract_problem.jl
@@ -0,0 +1 @@
+abstract type AbstractProblem end
diff --git a/src/iterators.jl b/src/iterators.jl
new file mode 100644
index 0000000..1fe4844
--- /dev/null
+++ b/src/iterators.jl
@@ -0,0 +1,173 @@
+"""
+    abstract type AbstractNetworkIterator
+
+A stateful iterator built around two operations: `increment!` and `compute!`. Each iteration begins
+with a call to `increment!` before executing `compute!`; however, the initial call to
+`iterate` skips the `increment!` call, as it is assumed the iterator is initialized such that
+this call is implicit. Termination of the iterator is controlled by the function `islaststep`.
+"""
+abstract type AbstractNetworkIterator end
+
+# We use greater than or equals here as we increment the state at the start of the iteration
+islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator)
+
+function Base.iterate(iterator::AbstractNetworkIterator, init = true)
+    # The assumption is that the first `increment!` is implicit, therefore we must skip
+    # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not
+    # defined when `length < 1`.
+    init || islaststep(iterator) && return nothing
+    # We separate `increment!` from `step!` and demand that any `AbstractNetworkIterator` *must*
+    # define a method for `increment!`. This way we avoid cases where one may wish to nest
+    # calls to different `step!` methods, accidentally incrementing multiple times.
+    init || increment!(iterator)
+    rv = compute!(iterator)
+    return rv, false
+end
+
+function increment! end
+compute!(iterator::AbstractNetworkIterator) = iterator
+
+step!(iterator::AbstractNetworkIterator) = step!(identity, iterator)
+function step!(f, iterator::AbstractNetworkIterator)
+    compute!(iterator)
+    f(iterator)
+    increment!(iterator)
+    return iterator
+end
+
+#
+# RegionIterator
+#
+"""
+    struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
+"""
+mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator
+    problem::Problem
+    region_plan::RegionPlan
+    which_region::Int
+    const which_sweep::Int
+    function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R}
+        if length(region_plan) == 0
+            throw(BoundsError("Cannot construct a region iterator with 0 elements."))
+        end
+        return new{P, R}(problem, region_plan, 1, sweep)
+    end
+end
+
+function RegionIterator(problem; sweep, sweep_kwargs...)
+    plan = region_plan(problem; sweep_kwargs...)
+    return RegionIterator(problem, plan, sweep)
+end
+
+function new_region_iterator(iterator::RegionIterator; sweep_kwargs...)
+    return RegionIterator(iterator.problem; sweep_kwargs...)
+end + +state(region_iter::RegionIterator) = region_iter.which_region +Base.length(region_iter::RegionIterator) = length(region_iter.region_plan) + +problem(region_iter::RegionIterator) = region_iter.problem + +function current_region_plan(region_iter::RegionIterator) + return region_iter.region_plan[region_iter.which_region] +end + +function current_region(region_iter::RegionIterator) + region, _ = current_region_plan(region_iter) + return region +end + +function region_kwargs(region_iter::RegionIterator) + _, kwargs = current_region_plan(region_iter) + return kwargs +end +function region_kwargs(f::Function, iter::RegionIterator) + return get(region_kwargs(iter), Symbol(f, :_kwargs), (;)) +end + +function prev_region(region_iter::RegionIterator) + state(region_iter) <= 1 && return nothing + prev, _ = region_iter.region_plan[region_iter.which_region - 1] + return prev +end + +function next_region(region_iter::RegionIterator) + islaststep(region_iter) && return nothing + next, _ = region_iter.region_plan[region_iter.which_region + 1] + return next +end + +# +# Functions associated with RegionIterator +# +function increment!(region_iter::RegionIterator) + region_iter.which_region += 1 + return region_iter +end + +function compute!(iter::RegionIterator) + _, local_state = extract!(iter; region_kwargs(extract!, iter)...) + _, local_state = update!(iter, local_state; region_kwargs(update!, iter)...) + insert!(iter, local_state; region_kwargs(insert!, iter)...) + + return iter +end + +region_plan(problem; sweep_kwargs...) = euler_sweep(state(problem); sweep_kwargs...) + +# +# SweepIterator +# + +mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator + region_iter::RegionIterator{Problem} + sweep_kwargs::Iterators.Stateful{Iter} + which_sweep::Int + function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} + stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) + first_state = Iterators.peel(stateful_sweep_kwargs) + + if isnothing(first_state) + throw(BoundsError("Cannot construct a sweep iterator with 0 elements.")) + end + + first_kwargs, _ = first_state + region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) + + return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) + end +end + +islaststep(sweep_iter::SweepIterator) = isnothing(peek(sweep_iter.sweep_kwargs)) + +region_iterator(sweep_iter::SweepIterator) = sweep_iter.region_iter +problem(sweep_iter::SweepIterator) = problem(region_iterator(sweep_iter)) + +state(sweep_iter::SweepIterator) = sweep_iter.which_sweep +Base.length(sweep_iter::SweepIterator) = length(sweep_iter.sweep_kwargs) +function increment!(sweep_iter::SweepIterator) + sweep_iter.which_sweep += 1 + sweep_kwargs, _ = Iterators.peel(sweep_iter.sweep_kwargs) + update_region_iterator!(sweep_iter; sweep_kwargs...) + return sweep_iter +end + +function update_region_iterator!(iterator::SweepIterator; kwargs...) + sweep = state(iterator) + iterator.region_iter = new_region_iterator(iterator.region_iter; sweep, kwargs...) + return iterator +end + +function compute!(sweep_iter::SweepIterator) + for _ in sweep_iter.region_iter + # TODO: Is it sensible to execute the default region callback function? + end + return +end + +# More basic constructor where sweep_kwargs are constant throughout sweeps +function SweepIterator(problem, nsweeps::Int; sweep_kwargs...) 
+    # Repeat the same sweep kwargs for each of the `nsweeps` sweeps.
+    sweep_kwargs_iter = Iterators.repeated(sweep_kwargs, nsweeps)
+    return SweepIterator(problem, sweep_kwargs_iter)
+end
diff --git a/test/test_iterators.jl b/test/test_iterators.jl
new file mode 100644
index 0000000..456ebcf
--- /dev/null
+++ b/test/test_iterators.jl
@@ -0,0 +1,161 @@
+using Test: @test, @testset, @test_throws
+import ITensorNetworksNext as ITensorNetworks
+using .ITensorNetworks: SweepIterator, RegionIterator, islaststep, state, increment!, compute!, eachregion
+
+module TestIteratorUtils
+
+    import ITensorNetworksNext as ITensorNetworks
+    using .ITensorNetworks
+
+    struct TestProblem <: ITensorNetworks.AbstractProblem
+        data::Vector{Int}
+    end
+    ITensorNetworks.region_plan(::TestProblem) = [:a => (; val = 1), :b => (; val = 2)]
+    function ITensorNetworks.compute!(iter::ITensorNetworks.RegionIterator{<:TestProblem})
+        kwargs = ITensorNetworks.region_kwargs(iter)
+        push!(ITensorNetworks.problem(iter).data, kwargs.val)
+        return iter
+    end
+
+
+    mutable struct TestIterator <: ITensorNetworks.AbstractNetworkIterator
+        state::Int
+        max::Int
+        output::Vector{Int}
+    end
+
+    ITensorNetworks.increment!(TI::TestIterator) = TI.state += 1
+    Base.length(TI::TestIterator) = TI.max
+    ITensorNetworks.state(TI::TestIterator) = TI.state
+    function ITensorNetworks.compute!(TI::TestIterator)
+        push!(TI.output, ITensorNetworks.state(TI))
+        return TI
+    end
+
+    mutable struct SquareAdapter <: ITensorNetworks.AbstractNetworkIterator
+        parent::TestIterator
+    end
+
+    Base.length(SA::SquareAdapter) = length(SA.parent)
+    ITensorNetworks.increment!(SA::SquareAdapter) = ITensorNetworks.increment!(SA.parent)
+    ITensorNetworks.state(SA::SquareAdapter) = ITensorNetworks.state(SA.parent)
+    function ITensorNetworks.compute!(SA::SquareAdapter)
+        ITensorNetworks.compute!(SA.parent)
+        return last(SA.parent.output)^2
+    end
+
+end
+
+@testset "Iterators" begin
+
+    import .TestIteratorUtils
+
+    @testset "`AbstractNetworkIterator` Interface" begin
+
+        @testset "Edge cases" begin
+            TI = TestIteratorUtils.TestIterator(1, 1, [])
+            cb = []
+            @test islaststep(TI)
+            for _ in TI
+                @test islaststep(TI)
+                push!(cb, state(TI))
+            end
+            @test length(cb) == 1
+            @test length(TI.output) == 1
+            @test only(cb) == 1
+
+            prob = TestIteratorUtils.TestProblem([])
+            @test_throws BoundsError SweepIterator(prob, 0)
+            @test_throws BoundsError RegionIterator(prob, [], 1)
+        end
+
+        TI = TestIteratorUtils.TestIterator(1, 4, [])
+
+        @test !islaststep((TI))
+
+        # The first call to `iterate` should only `compute!`, not `increment!`
+        rv, st = iterate(TI)
+        @test !islaststep((TI))
+        @test !st
+        @test rv === TI
+        @test length(TI.output) == 1
+        @test only(TI.output) == 1
+        @test state(TI) == 1
+        @test !st
+
+        rv, st = iterate(TI, st)
+        @test !islaststep((TI))
+        @test !st
+        @test length(TI.output) == 2
+        @test state(TI) == 2
+        @test TI.output == [1, 2]
+
+        increment!(TI)
+        @test !islaststep((TI))
+        @test state(TI) == 3
+        @test length(TI.output) == 2
+        @test TI.output == [1, 2]
+
+        compute!(TI)
+        @test !islaststep((TI))
+        @test state(TI) == 3
+        @test length(TI.output) == 3
+        @test TI.output == [1, 2, 3]
+
+        # Final Step
+        iterate(TI, false)
+        @test islaststep((TI))
+        @test state(TI) == 4
+        @test length(TI.output) == 4
+        @test TI.output == [1, 2, 3, 4]
+
+        @test iterate(TI, false) === nothing
+
+        TI = TestIteratorUtils.TestIterator(1, 5, [])
+
+        cb = []
+
+        for _ in TI
+            @test length(cb) == length(TI.output) - 1
+            @test cb == (TI.output)[1:(end - 1)]
+            push!(cb, state(TI))
+            @test cb == TI.output
+        end
+
+        @test islaststep((TI))
+        @test
length(TI.output) == 5 + @test length(cb) == 5 + @test cb == TI.output + + + TI = TestIteratorUtils.TestIterator(1, 5, []) + end + + @testset "Adapters" begin + TI = TestIteratorUtils.TestIterator(1, 5, []) + SA = TestIteratorUtils.SquareAdapter(TI) + + @testset "Generic" begin + + i = 0 + for rv in SA + i += 1 + @test rv isa Int + @test rv == i^2 + @test state(SA) == i + end + + @test islaststep((SA)) + + TI = TestIteratorUtils.TestIterator(1, 5, []) + SA = TestIteratorUtils.SquareAdapter(TI) + + SA_c = collect(SA) + + @test SA_c isa Vector + @test length(SA_c) == 5 + @test SA_c == [1, 4, 9, 16, 25] + + end + end +end From 013b8b4a137fd937b101b825e68a2bc6964df719 Mon Sep 17 00:00:00 2001 From: Jack Dunham Date: Tue, 28 Oct 2025 15:18:28 -0400 Subject: [PATCH 13/13] Express BP in terms of `SweepIterator` interface Introduce `BeliefPropagationProblem` wrapper to hold the cache and the error `diff` field. Also simplifies some kwargs wrangling. --- Project.toml | 2 + src/ITensorNetworksNext.jl | 1 + .../beliefpropagationcache.jl | 126 ++---------------- .../beliefpropagationproblem.jl | 85 ++++++++++++ 4 files changed, 101 insertions(+), 113 deletions(-) create mode 100644 src/beliefpropagation/beliefpropagationproblem.jl diff --git a/Project.toml b/Project.toml index 42df730..189983f 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ BackendSelection = "680c2d7c-f67a-4cc9-ae9c-da132b1447a5" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" +ITensorBase = "4795dd04-0d67-49bb-8f44-b89c448a1dc7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" @@ -33,6 +34,7 @@ BackendSelection = "0.1.6" DataGraphs = "0.2.7" Dictionaries = "0.4.5" Graphs = "1.13.1" +ITensorBase = "0.2.14" LinearAlgebra = "1.10" MacroTools = "0.5.16" NamedDimsArrays = "0.8" diff --git a/src/ITensorNetworksNext.jl b/src/ITensorNetworksNext.jl index 4ed09c0..8cc4dd0 100644 --- a/src/ITensorNetworksNext.jl +++ b/src/ITensorNetworksNext.jl @@ -9,5 +9,6 @@ include("iterators.jl") include("beliefpropagation/abstractbeliefpropagationcache.jl") include("beliefpropagation/beliefpropagationcache.jl") +include("beliefpropagation/beliefpropagationproblem.jl") end diff --git a/src/beliefpropagation/beliefpropagationcache.jl b/src/beliefpropagation/beliefpropagationcache.jl index 295502a..cdae651 100644 --- a/src/beliefpropagation/beliefpropagationcache.jl +++ b/src/beliefpropagation/beliefpropagationcache.jl @@ -1,9 +1,7 @@ -using DiagonalArrays: delta using Dictionaries: Dictionary, set!, delete! 
using Graphs: AbstractGraph, is_tree, connected_components using NamedGraphs.GraphsExtensions: default_root_vertex, forest_cover, post_order_dfs_edges using ITensorBase: ITensor, dim -using TypeParameterAccessors: unwrap_array_type, unwrap_array, parenttype struct BeliefPropagationCache{V, N <: AbstractDataGraph{V}} <: AbstractBeliefPropagationCache{V} @@ -13,9 +11,8 @@ end messages(bp_cache::BeliefPropagationCache) = bp_cache.messages network(bp_cache::BeliefPropagationCache) = bp_cache.network -default_messages() = Dictionary() -BeliefPropagationCache(network) = BeliefPropagationCache(network, default_messages()) +BeliefPropagationCache(network) = BeliefPropagationCache(network, Dictionary()) function Base.copy(bp_cache::BeliefPropagationCache) return BeliefPropagationCache(copy(network(bp_cache)), copy(messages(bp_cache))) @@ -33,16 +30,15 @@ function setmessage!(bp_cache::BeliefPropagationCache, e::AbstractEdge, message) return bp_cache end -function message(bp_cache::AbstractBeliefPropagationCache, edge::AbstractEdge; kwargs...) +function message(bp_cache::BeliefPropagationCache, edge::AbstractEdge; kwargs...) ms = messages(bp_cache) return get(() -> default_message(bp_cache, edge; kwargs...), ms, edge) end -function messages(bp_cache::AbstractBeliefPropagationCache, edges::Vector{<:AbstractEdge}) +function messages(bp_cache::BeliefPropagationCache, edges::Vector{<:AbstractEdge}) return [message(bp_cache, e) for e in edges] end -default_bp_maxiter(g::AbstractGraph) = is_tree(g) ? 1 : nothing #Forward onto the network for f in [ :(Graphs.vertices), @@ -62,11 +58,6 @@ for f in [ end end -#TODO: Get subgraph working on an ITensorNetwork to overload this directly -function default_bp_edge_sequence(bp_cache::BeliefPropagationCache) - return forest_cover_edge_sequence(underlying_graph(bp_cache)) -end - function factors(tn::AbstractTensorNetwork, vertex) return [tn[vertex]] end @@ -91,33 +82,6 @@ function default_message(tn::AbstractTensorNetwork, edge::AbstractEdge) return t end -#Algorithmic defaults -default_update_alg(bp_cache::BeliefPropagationCache) = "bp" -default_message_update_alg(bp_cache::BeliefPropagationCache) = "contract" -default_normalize(::Algorithm"contract") = true -default_sequence_alg(::Algorithm"contract") = "optimal" -function set_default_kwargs(alg::Algorithm"contract") - normalize = get(alg, :normalize, default_normalize(alg)) - sequence_alg = get(alg, :sequence_alg, default_sequence_alg(alg)) - return Algorithm("contract"; normalize, sequence_alg) -end -function set_default_kwargs(alg::Algorithm"adapt_update") - _alg = set_default_kwargs(get(alg, :alg, Algorithm("contract"))) - return Algorithm("adapt_update"; adapt = alg.adapt, alg = _alg) -end -default_verbose(::Algorithm"bp") = false -default_tol(::Algorithm"bp") = nothing -function set_default_kwargs(alg::Algorithm"bp", bp_cache::BeliefPropagationCache) - verbose = get(alg, :verbose, default_verbose(alg)) - maxiter = get(alg, :maxiter, default_bp_maxiter(bp_cache)) - edge_sequence = get(alg, :edge_sequence, default_bp_edge_sequence(bp_cache)) - tol = get(alg, :tol, default_tol(alg)) - message_update_alg = set_default_kwargs( - get(alg, :message_update_alg, Algorithm(default_message_update_alg(bp_cache))) - ) - return Algorithm("bp"; verbose, maxiter, edge_sequence, tol, message_update_alg) -end - #TODO: Update message etc should go here... 
 function updated_message(
         alg::Algorithm"contract", bp_cache::BeliefPropagationCache, edge::AbstractEdge
@@ -141,85 +105,21 @@ function updated_message(
     return updated_message
 end
 
-function updated_message(
-        bp_cache::BeliefPropagationCache,
-        edge::AbstractEdge;
-        alg = default_message_update_alg(bp_cache),
-        kwargs...,
+function default_algorithm(
+        ::Type{<:Algorithm"contract"}; normalize = true, sequence_alg = "optimal"
     )
-    return updated_message(set_default_kwargs(Algorithm(alg; kwargs...)), bp_cache, edge)
+    return Algorithm("contract"; normalize, sequence_alg)
 end
-
-function update_message!(
-        message_update_alg::Algorithm, bp_cache::BeliefPropagationCache, edge::AbstractEdge
+function default_algorithm(
+        ::Type{<:Algorithm"adapt_update"}; adapt, alg = default_algorithm(Algorithm"contract")
    )
-    return setmessage!(bp_cache, edge, updated_message(message_update_alg, bp_cache, edge))
+    return Algorithm("adapt_update"; adapt, alg)
 end
 
-"""
-Do a sequential update of the message tensors on `edges`
-"""
-function update_iteration(
-        alg::Algorithm"bp",
-        bpc::AbstractBeliefPropagationCache,
-        edges::Vector;
-        (update_diff!) = nothing,
-    )
-    bpc = copy(bpc)
-    for e in edges
-        prev_message = !isnothing(update_diff!) ? message(bpc, e) : nothing
-        update_message!(alg.message_update_alg, bpc, e)
-        if !isnothing(update_diff!)
-            update_diff![] += message_diff(message(bpc, e), prev_message)
-        end
-    end
-    return bpc
-end
-
-"""
-Do parallel updates between groups of edges of all message tensors.
-Currently we send the full message tensor data structure to update each edge group, but really we only need
-the messages relevant to that group.
-"""
-function update_iteration(
-        alg::Algorithm"bp",
-        bpc::AbstractBeliefPropagationCache,
-        edge_groups::Vector{<:Vector{<:AbstractEdge}};
-        (update_diff!) = nothing,
+function update_message!(
+        message_update_alg::Algorithm, bpc::BeliefPropagationCache, edge::AbstractEdge
     )
-    new_mts = empty(messages(bpc))
-    for edges in edge_groups
-        bpc_t = update_iteration(alg, bpc, edges; (update_diff!))
-        for e in edges
-            set!(new_mts, e, message(bpc_t, e))
-        end
-    end
-    return set_messages(bpc, new_mts)
-end
-
-"""
-More generic interface for update, with default params
-"""
-function update(alg::Algorithm"bp", bpc::AbstractBeliefPropagationCache)
-    compute_error = !isnothing(alg.tol)
-    if isnothing(alg.maxiter)
-        error("You need to specify a number of iterations for BP!")
-    end
-    for i in 1:alg.maxiter
-        diff = compute_error ? Ref(0.0) : nothing
-        bpc = update_iteration(alg, bpc, alg.edge_sequence; (update_diff!) = diff)
-        if compute_error && (diff.x / length(alg.edge_sequence)) <= alg.tol
-            if alg.verbose
-                println("BP converged to desired precision after $i iterations.")
-            end
-            break
-        end
-    end
-    return bpc
-end
-
-function update(bpc::AbstractBeliefPropagationCache; alg = default_update_alg(bpc), kwargs...)
- return update(set_default_kwargs(Algorithm(alg; kwargs...), bpc), bpc) + return setmessage!(bpc, edge, updated_message(message_update_alg, bpc, edge)) end #Edge sequence stuff @@ -234,4 +134,4 @@ function forest_cover_edge_sequence(g::AbstractGraph; root_vertex = default_root end end return edges -end \ No newline at end of file +end diff --git a/src/beliefpropagation/beliefpropagationproblem.jl b/src/beliefpropagation/beliefpropagationproblem.jl new file mode 100644 index 0000000..a497363 --- /dev/null +++ b/src/beliefpropagation/beliefpropagationproblem.jl @@ -0,0 +1,85 @@ +mutable struct BeliefPropagationProblem{V, Cache <: AbstractBeliefPropagationCache{V}} <: + AbstractProblem + const cache::Cache + diff::Union{Nothing, Float64} +end + +function default_algorithm( + ::Type{<:Algorithm"bp"}, + bpc::BeliefPropagationCache; + verbose = false, + tol = nothing, + edge_sequence = forest_cover_edge_sequence(underlying_graph(bpc)), + message_update_alg = default_algorithm(Algorithm"contract"), + maxiter = is_tree(bpc) ? 1 : nothing, + ) + return Algorithm("bp"; verbose, tol, edge_sequence, message_update_alg, maxiter) +end + +function compute!(iter::RegionIterator{<:BeliefPropagationProblem}) + prob = iter.problem + + edge_group, kwargs = current_region_plan(iter) + + new_message_tensors = map(edge_group) do edge + old_message = message(prob.cache, edge) + + new_message = updated_message(kwargs.message_update_alg, prob.cache, edge) + + if !isnothing(prob.diff) + # TODO: Define `message_diff` + prob.diff += message_diff(new_message, old_message) + end + + return new_message + end + + foreach(edge_group, new_message_tensors) do edge, new_message + setmessage!(prob.cache, edge, new_message) + end + + return iter +end + +function region_plan( + prob::BeliefPropagationProblem; root_vertex = default_root_vertex, sweep_kwargs... + ) + edges = forest_cover_edge_sequence(underlying_graph(prob.cache); root_vertex) + + plan = map(edges) do e + return [e] => (; sweep_kwargs...) + end + + return plan +end + +function update(bpc::AbstractBeliefPropagationCache; kwargs...) + return update(default_algorithm(Algorithm"bp", bpc; kwargs...), bpc) +end +function update(alg::Algorithm"bp", bpc) + compute_error = !isnothing(alg.tol) + + diff = compute_error ? 0.0 : nothing + + prob = BeliefPropagationProblem(bpc, diff) + + iter = SweepIterator(prob, alg.maxiter; compute_error, getfield(alg, :kwargs)...) + + for _ in iter + if compute_error && prob.diff <= alg.tol + break + end + end + + if alg.verbose && compute_error + if prob.diff <= alg.tol + println("BP converged to desired precision after $(iter.which_sweep) iterations.") + else + println( + "BP failed to converge to precision $(alg.tol), got $(prob.diff) after $(iter.which_sweep) iterations", + ) + end + end + + return bpc +end
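
The net effect of patches 12 and 13 is that `ITensorNetworksNext.update` now drives belief propagation through the generic `SweepIterator`/`RegionIterator` machinery rather than a hand-rolled loop. The sketch below condenses the usage exercised in test/test_beliefpropagation.jl above; the comb-tree shape, bond dimension, and `maxiter` value are illustrative choices rather than requirements of the API, and all names are taken from the tests.

    using Graphs: edges, vertices
    using ITensorBase: Index
    using ITensorNetworksNext: ITensorNetworksNext, BeliefPropagationCache, TensorNetwork,
        partitionfunction
    using NamedGraphs.NamedGraphGenerators: named_comb_tree
    using NamedGraphs.GraphsExtensions: incident_edges

    # Build a tree of random tensors with bond dimension 3 on every edge.
    g = named_comb_tree((4, 3))
    l = Dict(e => Index(3) for e in edges(g))
    l = merge(l, Dict(reverse(e) => l[e] for e in edges(g)))
    tn = TensorNetwork(g) do v
        is = map(e -> l[e], incident_edges(g, v))
        return randn(Tuple(is))
    end

    # `update` wraps the cache in a `BeliefPropagationProblem` and sweeps it with a
    # `SweepIterator`; since belief propagation is exact on trees, the resulting
    # partition function matches the full contraction.
    bpc = BeliefPropagationCache(tn)
    bpc = ITensorNetworksNext.update(bpc; maxiter = 10)
    z_bp = partitionfunction(bpc)
    z_exact = reduce(*, [tn[v] for v in vertices(g)])[]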