diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE
index cc83efc42f..d15a5f0726 100644
--- a/deps/ReactantExtra/WORKSPACE
+++ b/deps/ReactantExtra/WORKSPACE
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
 
 NSYNC_SHA256 = ""
 
-ENZYMEXLA_COMMIT = "8dc549db2c67dd4940743b6fbca96af2cab41de5"
+ENZYMEXLA_COMMIT = "c5b0090d53998673b2f728b7590b97d7bc548d2b"
 
 ENZYMEXLA_SHA256 = ""
 
diff --git a/src/Ops.jl b/src/Ops.jl
index 252d87dc84..16e7aeaf65 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -934,6 +934,55 @@ end
     return TracedRArray{T,N}((), MLIR.IR.result(conv), result_size)
 end
 
+"""
+    lapack_symm(A, B, C, alpha, beta; side, uplo, location)
+
+Emit an `enzymexla.lapack_symm` op on the MLIR values of `A`, `B`, `C`, `alpha` and
+`beta`, and return its result as a `TracedRArray{T}` with the size of `C`.
+
+`side` must be `:L` or `:R` and `uplo` must be `:U` or `:L`; any other symbol throws
+an `ArgumentError`.
+
+NOTE(review): presumably this mirrors BLAS `xSYMM` (`A` symmetric, `side` selecting
+`A * B` vs. `B * A`) -- confirm against the EnzymeXLA dialect definition.
+"""
+@noinline function lapack_symm(
+    A::TracedRArray{T},
+    B::TracedRArray{T},
+    C::TracedRArray{T},
+    alpha::TracedRNumber{T},
+    beta::TracedRNumber{T};
+    side::Symbol,
+    uplo::Symbol,
+    location=mlir_stacktrace("lapack_symm", @__FILE__, @__LINE__),
+) where {T}
+    # Without this check an unknown symbol would silently fall into the `0`
+    # branch of the attribute encodings below.
+    side in (:L, :R) ||
+        throw(ArgumentError("side must be :L or :R, got $(repr(side))"))
+    uplo in (:U, :L) ||
+        throw(ArgumentError("uplo must be :U or :L, got $(repr(uplo))"))
+
+    ctx = MLIR.IR.context()
+    ressize = size(C)
+    resT = mlir_type(TracedRArray{unwrapped_eltype(C),length(ressize)}, ressize)
+
+    res = MLIR.IR.result(
+        enzymexla.lapack_symm(
+            A.mlir_data,
+            B.mlir_data,
+            C.mlir_data,
+            alpha.mlir_data,
+            beta.mlir_data;
+            output=resT,
+            side=MLIR.API.enzymexlaLapackSideAttrGet(ctx, side == :L ? 1 : 0),
+            uplo=MLIR.API.enzymexlaLapackUploAttrGet(ctx, uplo == :U ? 1 : 0),
+            location,
+        ),
+    )
+    return TracedRArray{T,length(ressize)}((), res, ressize)
+end
+
 Base.@nospecializeinfer @noinline function dot_general(
     @nospecialize(lhs::TracedRArray{T1}),
     @nospecialize(rhs::TracedRArray{T2});
diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index d4820c9966..9831f704f4 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -273,6 +273,89 @@ function overloaded_mul!(
     return C
 end
 
+# Multiply with a `Symmetric` left operand via a single `lapack_symm` op.
+# β defaults to `false`, matching `LinearAlgebra.mul!(C, A, B)`.
+function overloaded_mul!(
+    @nospecialize(C::TracedRArray{T,2} where {T}),
+    @nospecialize(A::Symmetric),
+    @nospecialize(B::AbstractMatrix),
+    α::Number=true,
+    β::Number=false,
+)
+    # Capture the stored triangle before unwrapping: only that half of
+    # `parent(A)` is guaranteed to hold the matrix data.
+    uplo = Symbol(A.uplo)
+
+    # Promote to traced arrays
+    A = call_with_reactant(Reactant.promote_to, TracedRArray, parent(A))
+    B = call_with_reactant(Reactant.promote_to, TracedRArray, B)
+
+    # Dimension checks
+    if size(C) != (size(A, 1), size(B, 2))
+        throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))"))
+    end
+    if size(A, 2) != size(B, 1)
+        throw(DimensionMismatch("A=$(size(A)), B=$(size(B))"))
+    end
+
+    T = Reactant.unwrapped_eltype(C)
+    tmp = @opcall lapack_symm(
+        T.(materialize_traced_array(A)),
+        T.(materialize_traced_array(B)),
+        T.(materialize_traced_array(C)),
+        Reactant.promote_to(TracedRNumber{T}, α),
+        Reactant.promote_to(TracedRNumber{T}, β),
+        side=:L,
+        uplo=uplo,
+    )
+
+    set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird
+    return C
+end
+
+# Multiply with a `Symmetric` right operand via a single `lapack_symm` op.
+# β defaults to `false`, matching `LinearAlgebra.mul!(C, A, B)`.
+function overloaded_mul!(
+    @nospecialize(C::TracedRArray{T,2} where {T}),
+    @nospecialize(A::AbstractMatrix),
+    @nospecialize(B::Symmetric),
+    α::Number=true,
+    β::Number=false,
+)
+    # Capture the stored triangle before unwrapping: only that half of
+    # `parent(B)` is guaranteed to hold the matrix data.
+    uplo = Symbol(B.uplo)
+
+    # Promote to traced arrays
+    A = call_with_reactant(Reactant.promote_to, TracedRArray, A)
+    B = call_with_reactant(Reactant.promote_to, TracedRArray, parent(B))
+
+    # Dimension checks
+    if size(C) != (size(A, 1), size(B, 2))
+        throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))"))
+    end
+    if size(A, 2) != size(B, 1)
+        throw(DimensionMismatch("A=$(size(A)), B=$(size(B))"))
+    end
+
+    T = Reactant.unwrapped_eltype(C)
+    # NOTE(review): BLAS `xSYMM` takes the symmetric matrix as its first operand
+    # for both sides; here the symmetric `B` is passed second -- confirm the
+    # operand order expected by `enzymexla.lapack_symm` for `side=:R`.
+    tmp = @opcall lapack_symm(
+        T.(materialize_traced_array(A)),
+        T.(materialize_traced_array(B)),
+        T.(materialize_traced_array(C)),
+        Reactant.promote_to(TracedRNumber{T}, α),
+        Reactant.promote_to(TracedRNumber{T}, β),
+        side=:R,
+        uplo=uplo,
+    )
+
+    set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird
+    return C
+end
+
 function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
     iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1)
     iota_2 = @opcall subtract(
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index 5790bfc928..e6bc28913b 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -432,3 +432,18 @@ end
         1e-2
     end
 end
+
+@testset "Symmetric Multiplication" begin
+    @testset "$name" for (name, T) in (("F32", Float32), ("F64", Float64))
+        A = Symmetric(rand(T, 10, 10))
+        B = rand(T, 10, 10)
+        A_ra = Reactant.to_rarray(A)
+        B_ra = Reactant.to_rarray(B)
+        alpha = rand(T)
+
+        # `@test` needs a Bool, so inspect the printed module for the symm op
+        # rather than handing the `@code_hlo` result to `@test` directly.
+        hlo = @code_hlo optimize=false A_ra * B_ra * alpha
+        @test occursin("symm", repr(hlo))
+    end
+end