|
1 | | -const NCONSTYLE = "Valid ncon style network" |
| 1 | +# Verify if a list of indices specifies a tensor contraction in ncon style. |
| 2 | +check_nconstyle(::Type{Bool}, network) = _check_nconstyle(network, Val(true)) |
| 3 | +check_nconstyle(network) = _check_nconstyle(network, Val(false)) |
2 | 4 |
|
3 | | -# check if a list of indices specifies a tensor contraction in ncon style |
4 | | -function isnconstyle(network) |
5 | | - return _nconstyle_error(network) == NCONSTYLE |
6 | | -end |
7 | | - |
8 | | -function _nconstyle_error(network) |
| 5 | +function _check_nconstyle(network, ::Val{check}) where {check} |
9 | 6 | allindices = Vector{Int}() |
10 | 7 | for ind in network |
11 | | - all(i -> isa(i, Integer), ind) || return "All indices must be integers" |
| 8 | + all(i -> isa(i, Integer), ind) || return check ? false : |
| 9 | + throw(IndexError("All indices must be integers")) |
12 | 10 | append!(allindices, ind) |
13 | 11 | end |
14 | 12 | while length(allindices) > 0 |
15 | 13 | i = pop!(allindices) |
16 | 14 | if i > 0 # positive labels represent contractions or traces and should appear twice |
17 | 15 | k = findfirst(isequal(i), allindices) |
18 | | - k === nothing && return "Index $i appears only once in the network" |
| 16 | + isnothing(k) && return check ? false : |
| 17 | + throw(IndexError(lazy"Index $i appears only once in the network")) |
19 | 18 | l = findnext(isequal(i), allindices, k + 1) |
20 | | - l !== nothing && return "Index $i appears more than twice in the network" |
| 19 | + !isnothing(l) && return check ? false : |
| 20 | + throw(IndexError(lazy"Index $i appears more than twice in the network")) |
21 | 21 | deleteat!(allindices, k) |
22 | 22 | elseif i < 0 # negative labels represent open indices and should appear once |
23 | | - findfirst(isequal(i), allindices) === nothing || return "Index $i appears more than once in the network" |
| 23 | + isnothing(findfirst(isequal(i), allindices)) || return check ? false : |
| 24 | + throw(IndexError(lazy"Index $i appears more than once in the network")) |
24 | 25 | else # i == 0 |
25 | | - return "Index 0 is not allowed in the network" |
| 26 | + return check ? false : throw(IndexError("Index 0 is not allowed in the network")) |
26 | 27 | end |
27 | 28 | end |
28 | | - return NCONSTYLE |
29 | | -end |
30 | | - |
31 | | -function nconstylecheck(network) |
32 | | - err = _nconstyle_error(network) |
33 | | - err === NCONSTYLE || throw(ArgumentError(err)) |
34 | | - return nothing |
| 29 | + return check ? true : nothing |
35 | 30 | end |
36 | 31 |
|
37 | 32 | function ncontree(network) |
|
0 commit comments