Skip to content

Commit db411ea

Browse files
araujomsGianmarcoCuppari
authored andcommitted
make checksquare fail on non-matrices (#1476)
Closes #1362
1 parent 6ef853e commit db411ea

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

src/LinearAlgebra.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ StridedMatrixStride1{T} = StridedArrayStride1{T,2}
326326
"""
327327
LinearAlgebra.checksquare(A)
328328
329-
Check that a matrix is square, then return its common dimension.
329+
Checks whether a matrix is square, returning its common dimension if it is the case, or throwing a DimensionMismatch error otherwise.
330330
For multiple arguments, return a vector.
331331
332332
# Examples
@@ -340,19 +340,13 @@ julia> LinearAlgebra.checksquare(A, B)
340340
```
341341
"""
342342
function checksquare(A)
343-
m,n = size(A)
344-
m == n || throw(DimensionMismatch(lazy"matrix is not square: dimensions are $(size(A))"))
345-
m
343+
sizeA = size(A)
344+
length(sizeA) == 2 || throw(DimensionMismatch(lazy"input is not a matrix: dimensions are $sizeA"))
345+
sizeA[1] == sizeA[2] || throw(DimensionMismatch(lazy"matrix is not square: dimensions are $sizeA"))
346+
return sizeA[1]
346347
end
347348

348-
function checksquare(A...)
349-
sizes = Int[]
350-
for a in A
351-
size(a,1)==size(a,2) || throw(DimensionMismatch(lazy"matrix is not square: dimensions are $(size(a))"))
352-
push!(sizes, size(a,1))
353-
end
354-
return sizes
355-
end
349+
checksquare(A...) = [checksquare(a) for a in A]
356350

357351
function char_uplo(uplo::Symbol)
358352
if uplo === :U

test/dense.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,4 +1430,15 @@ end
14301430
@test log(D) log(UpperTriangular(D))
14311431
end
14321432

1433+
@testset "issue 1362" begin
1434+
A = zeros(2,2)
1435+
B = zeros(2,3)
1436+
C = zeros(2,2,1)
1437+
@test LinearAlgebra.checksquare(A) == 2
1438+
@test LinearAlgebra.checksquare(A,A) == [2, 2]
1439+
@test_throws DimensionMismatch LinearAlgebra.checksquare(B)
1440+
@test_throws DimensionMismatch LinearAlgebra.checksquare(C)
1441+
@test_throws DimensionMismatch LinearAlgebra.checksquare(A,B)
1442+
end
1443+
14331444
end # module TestDense

0 commit comments

Comments
 (0)