Skip to content

Commit a902e9b

Browse files
authored
Fix tr in Julia 1.12 (#633)
1 parent a3af656 commit a902e9b

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

src/host/linalg.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w
5858
copyto!(A, Transpose(Array(parent(B))))
5959
end
6060

61+
## trace
62+
63+
function LinearAlgebra.tr(A::AnyGPUMatrix)
64+
LinearAlgebra.checksquare(A)
65+
sum(diag(A))
66+
end
67+
6168
## copy upper triangle to lower and vice versa
6269

6370
function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false)

test/testsuite/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414
@test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128))
1515
end
1616

17+
@testset "tr" begin
18+
@test compare(tr, AT, rand(Float32, 32, 32))
19+
end
20+
1721
@testset "permutedims" begin
1822
@test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3))
1923
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))

0 commit comments

Comments
 (0)