diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index a6c78bea14..b0987cb317 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "bb0db132691f945bbc82fe4812bbf4c200340d37" +ENZYMEXLA_COMMIT = "52805e23ffb3dde87974b93a0cc3e75cd11fc4ad" ENZYMEXLA_SHA256 = "" diff --git a/src/Ops.jl b/src/Ops.jl index 22bf67679b..cae1c9bc90 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -3312,6 +3312,46 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors ` return (res, ipiv, perm, info) end +@noinline function svd( + x::TracedRArray{T,N}, + ::Type{iT}=Int32; + full::Bool=false, + location=mlir_stacktrace("svd", @__FILE__, @__LINE__), +) where {T,iT,N} + @assert N >= 2 + + batch_sizes = size(x)[1:(end - 2)] + m, n = size(x)[(end - 1):end] + r = min(m, n) + + U_size = (batch_sizes..., m, full ? m : r) + S_size = (batch_sizes..., r) + Vt_size = (batch_sizes..., full ? n : r, n) + info_size = batch_sizes + + svd_op = enzymexla.linalg_svd( + x.mlir_data; + U=mlir_type(TracedRArray{T,N}, U_size), + S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size), + Vt=mlir_type(TracedRArray{T,N}, Vt_size), + info=mlir_type(TracedRArray{iT,N - 2}, info_size), + full=full, + location, + ) + + U = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 1), U_size) + S = TracedRArray{Base.real(T),N - 1}((), MLIR.IR.result(svd_op, 2), S_size) + Vt = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 3), Vt_size) + + if N == 2 + info = TracedRNumber{iT}((), MLIR.IR.result(svd_op, 4)) + else + info = TracedRArray{iT,N - 2}((), MLIR.IR.result(svd_op, 4), info_size) + end + + return U, S, Vt, info +end + @noinline function reduce_window( f::F, inputs::Vector{TracedRArray{T,N}}, diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index b92a5d1177..72bfb310ca 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -27,6 +27,10 @@ function __init__() (BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_), (BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_), (BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_), + (BLAS.@blasfunc(sgesvd_), :enzymexla_lapack_sgesvd_), + (BLAS.@blasfunc(dgesvd_), :enzymexla_lapack_dgesvd_), + (BLAS.@blasfunc(cgesvd_), :enzymexla_lapack_cgesvd_), + (BLAS.@blasfunc(zgesvd_), :enzymexla_lapack_zgesvd_), ] sym = Libdl.dlsym(libblastrampoline_handle, cname) @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(