From cbf278560754dc59b9a0b72fb5143a87ebe9115e Mon Sep 17 00:00:00 2001 From: snonk Date: Tue, 28 Oct 2025 20:54:50 -0500 Subject: [PATCH 1/3] clean up symm --- deps/ReactantExtra/WORKSPACE | 2 +- deps/ReactantExtra/make-bindings.jl | 38 +- src/Ops.jl | 30 + src/mlir/Dialects/EnzymeXLA.jl | 974 +++++++++------------------- src/stdlibs/LinearAlgebra.jl | 62 ++ test/integration/linear_algebra.jl | 31 + 6 files changed, 449 insertions(+), 688 deletions(-) diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index cc83efc42f..d15a5f0726 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "8dc549db2c67dd4940743b6fbca96af2cab41de5" +ENZYMEXLA_COMMIT = "c5b0090d53998673b2f728b7590b97d7bc548d2b" ENZYMEXLA_SHA256 = "" diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index f84309fef1..ef14ab82f1 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -22,26 +22,26 @@ end src_dir = joinpath(dirname(dirname(@__DIR__)), "src") for file in [ - "Builtin.jl", - "Arith.jl", - "Affine.jl", - "Func.jl", - "Enzyme.jl", + # "Builtin.jl", + # "Arith.jl", + # "Affine.jl", + # "Func.jl", + # "Enzyme.jl", "EnzymeXLA.jl", - "StableHLO.jl", - "CHLO.jl", - "VHLO.jl", - "Llvm.jl", - "Nvvm.jl", - "Gpu.jl", - "Affine.jl", - "TPU.jl", - "MosaicGPU.jl", - "Triton.jl", - "Shardy.jl", - "MPI.jl", - "MemRef.jl", - "SparseTensor.jl", + # "StableHLO.jl", + # "CHLO.jl", + # "VHLO.jl", + # "Llvm.jl", + # "Nvvm.jl", + # "Gpu.jl", + # "Affine.jl", + # "TPU.jl", + # "MosaicGPU.jl", + # "Triton.jl", + # "Shardy.jl", + # "MPI.jl", + # "MemRef.jl", + # "SparseTensor.jl", ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/src/Ops.jl b/src/Ops.jl index 252d87dc84..cbfbd9decd 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -934,6 +934,36 @@ end return 
TracedRArray{T,N}((), MLIR.IR.result(conv), result_size) end +@noinline function lapack_symm( + A::TracedRArray{T}, + B::TracedRArray{T}, + C::TracedRArray{T}, + alpha::TracedRNumber{T}, + beta::TracedRNumber{T}; + side::Symbol, + uplo::Symbol, + location=mlir_stacktrace("lapack_symm", @__FILE__, @__LINE__), +) where {T} + ctx = MLIR.IR.context() + ressize = size(C) + resT = mlir_type(TracedRArray{unwrapped_eltype(C),length(ressize)}, ressize) + + res = MLIR.IR.result( + enzymexla.lapack_symm( + A.mlir_data, + B.mlir_data, + C.mlir_data, + alpha.mlir_data, + beta.mlir_data; + output=resT, + side=MLIR.API.enzymexlaLapackSideAttrGet(ctx, side == :L ? 1 : 0), + uplo=MLIR.API.enzymexlaLapackUploAttrGet(ctx, uplo == :U ? 1 : 0), + location, + ), + ) + return TracedRArray{T,length(ressize)}((), res, ressize) +end + Base.@nospecializeinfer @noinline function dot_general( @nospecialize(lhs::TracedRArray{T1}), @nospecialize(rhs::TracedRArray{T2}); diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index 4de6d8a13f..239c9442ea 100755 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -1,259 +1,168 @@ module enzymexla using ...IR -import ...IR: - NamedAttribute, - Value, - Location, - Block, - Region, - Attribute, - create_operation, - context, - IndexType +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API -function scope( - operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location() -) - op_ty_results = IR.Type[results...,] - operands = Value[operands...,] - owned_regions = Region[region,] + + +function scope(operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[operands..., ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return
create_operation( - "enzymexla.scope", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.scope", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function alternatives(; regions::Vector{Region}, location=Location()) op_ty_results = IR.Type[] operands = Value[] - owned_regions = Region[regions...,] + owned_regions = Region[regions..., ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.alternatives", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.alternatives", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function barrier(indices::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[indices...,] + operands = Value[indices..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.barrier", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.barrier", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function cacheload( - memref::Value, indices::Vector{Value}; result::IR.Type, location=Location() -) - op_ty_results = IR.Type[result,] - operands = Value[memref, indices...] 
- owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[] - - return create_operation( - "enzymexla.cacheload", - location; - operands, - owned_regions, - successors, - attributes, - results=op_ty_results, - result_inference=false, - ) -end function comm_region(; result_0::Vector{IR.Type}, body::Region, location=Location()) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result_0..., ] operands = Value[] - owned_regions = Region[body,] + owned_regions = Region[body, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.comm_region", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.comm_region", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function extend( - operand::Value; - result=nothing::Union{Nothing,IR.Type}, - lhs, - rhs, - dimension, - location=Location(), -) + +function extend(operand::Value; result=nothing::Union{Nothing, IR.Type}, lhs, rhs, dimension, location=Location()) op_ty_results = IR.Type[] - operands = Value[operand,] + operands = Value[operand, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("lhs", lhs), - namedattribute("rhs", rhs), - namedattribute("dimension", dimension), - ] + attributes = NamedAttribute[namedattribute("lhs", lhs), namedattribute("rhs", rhs), namedattribute("dimension", dimension), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.extend", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.extend", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? 
true : false) ) end -function gpu_block( - blockIndexX::Value, - blockIndexY::Value, - blockIndexZ::Value; - region::Region, - location=Location(), -) + +function gpu_block(blockIndexX::Value, blockIndexY::Value, blockIndexZ::Value; region::Region, location=Location()) op_ty_results = IR.Type[] - operands = Value[blockIndexX, blockIndexY, blockIndexZ] - owned_regions = Region[region,] + operands = Value[blockIndexX, blockIndexY, blockIndexZ, ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_block", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_block", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function gpu_error(; result::IR.Type, region::Region, location=Location()) - op_ty_results = IR.Type[result,] + op_ty_results = IR.Type[result, ] operands = Value[] - owned_regions = Region[region,] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_error", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_error", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function gpu_kernel_address(; result::IR.Type, fn, location=Location()) - op_ty_results = IR.Type[result,] + op_ty_results = IR.Type[result, ] operands = Value[] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - - return create_operation( - "enzymexla.gpu_kernel_address", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), ] + + create_operation( + "enzymexla.gpu_kernel_address", location; + operands, owned_regions, 
successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function gpu_occupancy( - blockSize::Value, - dynamicSMemSize::Value, - flags::Value; - result::IR.Type, - fn, - location=Location(), -) - op_ty_results = IR.Type[result,] - operands = Value[blockSize, dynamicSMemSize, flags] + +function gpu_occupancy(blockSize::Value, dynamicSMemSize::Value, flags::Value; result::IR.Type, fn, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[blockSize, dynamicSMemSize, flags, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - - return create_operation( - "enzymexla.gpu_occupancy", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), ] + + create_operation( + "enzymexla.gpu_occupancy", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function gpu_thread( - threadIndexX::Value, - threadIndexY::Value, - threadIndexZ::Value; - region::Region, - location=Location(), -) + +function gpu_thread(threadIndexX::Value, threadIndexY::Value, threadIndexZ::Value; region::Region, location=Location()) op_ty_results = IR.Type[] - operands = Value[threadIndexX, threadIndexY, threadIndexZ] - owned_regions = Region[region,] + operands = Value[threadIndexX, threadIndexY, threadIndexZ, ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_thread", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_thread", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -261,52 +170,38 @@ end `gpu_wrapper` The optional arguments to this operation are suggestions about what block -dimensions this 
gpu kernel should have - usually taken f rom kernel - launch params +dimensions this gpu kernel should have - usually taken from kernel launch +params """ -function gpu_wrapper( - blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location() -) - op_ty_results = IR.Type[result,] - operands = Value[blockDims...,] - owned_regions = Region[region,] +function gpu_wrapper(blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[blockDims..., ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_wrapper", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_wrapper", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function ml_gelu( - input::Value; - result=nothing::Union{Nothing,IR.Type}, - gelu_approximation, - location=Location(), -) + +function ml_gelu(input::Value; result=nothing::Union{Nothing, IR.Type}, gelu_approximation, location=Location()) op_ty_results = IR.Type[] - operands = Value[input,] + operands = Value[input, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("gelu_approximation", gelu_approximation),] + attributes = NamedAttribute[namedattribute("gelu_approximation", gelu_approximation), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.ml.gelu", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.ml.gelu", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? 
true : false) ) end @@ -315,31 +210,19 @@ end This operation is modeled after LAPACK\'s *GEMQR routines. """ -function lapack_gemqrt( - V::Value, - T::Value, - C::Value; - output::IR.Type, - side, - transpose=nothing, - location=Location(), -) - op_ty_results = IR.Type[output,] - operands = Value[V, T, C] +function lapack_gemqrt(V::Value, T::Value, C::Value; output::IR.Type, side, transpose=nothing, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[V, T, C, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side),] + attributes = NamedAttribute[namedattribute("side", side), ] !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) - - return create_operation( - "enzymexla.lapack.gemqrt", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.gemqrt", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -348,31 +231,23 @@ end This operation computes the QR factorization of a matrix using Householder reflections. Mathematically, it decomposes A into the product of an -orthogonal matri x Q and an upper triangular matrix R, - such that A = QR. +orthogonal matrix Q and an upper triangular matrix R, such that A = QR. - This operation is modeled after - LAPACK\'s *GEQRF routines, which returns the result in - the QR packed format. +This operation is modeled after LAPACK\'s *GEQRF routines, which returns the +result in the QR packed format. 
""" -function lapack_geqrf( - input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location() -) - op_ty_results = IR.Type[output, tau, info] - operands = Value[input,] +function lapack_geqrf(input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()) + op_ty_results = IR.Type[output, tau, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.lapack.geqrf", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.geqrf", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -386,190 +261,97 @@ orthogonal matrix Q and an upper triangular matrix R, such that A = QR. This operation is modeled after LAPACK\'s *GEQRT routines, which returns the result in the QR CompactWY format. """ -function lapack_geqrt( - input::Value; - output::IR.Type, - T::IR.Type, - info::IR.Type, - blocksize=nothing, - location=Location(), -) - op_ty_results = IR.Type[output, T, info] - operands = Value[input,] +function lapack_geqrt(input::Value; output::IR.Type, T::IR.Type, info::IR.Type, blocksize=nothing, location=Location()) + op_ty_results = IR.Type[output, T, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(blocksize) && push!(attributes, namedattribute("blocksize", blocksize)) - - return create_operation( - "enzymexla.lapack.geqrt", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.geqrt", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function get_stream(; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] + op_ty_results = IR.Type[result, ] operands = Value[] 
owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.get_stream", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.get_stream", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, - ) -end - -function jit_call( - inputs::Vector{Value}; - result_0::Vector{IR.Type}, - fn, - backend_config=nothing, - operand_layouts=nothing, - result_layouts=nothing, - arg_attrs=nothing, - res_attrs=nothing, - output_operand_aliases=nothing, - xla_side_effect_free=nothing, - location=Location(), -) - op_ty_results = IR.Type[result_0...,] - operands = Value[inputs...,] + result_inference=false + ) +end + + +function jit_call(inputs::Vector{Value}; result_0::Vector{IR.Type}, fn, backend_config=nothing, operand_layouts=nothing, result_layouts=nothing, arg_attrs=nothing, res_attrs=nothing, output_operand_aliases=nothing, xla_side_effect_free=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - !isnothing(backend_config) && - push!(attributes, namedattribute("backend_config", backend_config)) - !isnothing(operand_layouts) && - push!(attributes, namedattribute("operand_layouts", operand_layouts)) - !isnothing(result_layouts) && - push!(attributes, namedattribute("result_layouts", result_layouts)) + attributes = NamedAttribute[namedattribute("fn", fn), ] + !isnothing(backend_config) && push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && push!(attributes, namedattribute("result_layouts", result_layouts)) !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && 
push!(attributes, namedattribute("res_attrs", res_attrs)) - !isnothing(output_operand_aliases) && - push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) - !isnothing(xla_side_effect_free) && - push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) - - return create_operation( - "enzymexla.jit_call", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(output_operand_aliases) && push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + create_operation( + "enzymexla.jit_call", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, - ) -end - -function kernel_call( - gridx::Value, - gridy::Value, - gridz::Value, - blockx::Value, - blocky::Value, - blockz::Value, - shmem::Value, - clusterx=nothing::Union{Nothing,Value}; - clustery=nothing::Union{Nothing,Value}, - clusterz=nothing::Union{Nothing,Value}, - inputs::Vector{Value}, - result_0::Vector{IR.Type}, - fn, - backend_config=nothing, - operand_layouts=nothing, - result_layouts=nothing, - arg_attrs=nothing, - res_attrs=nothing, - output_operand_aliases=nothing, - xla_side_effect_free=nothing, - location=Location(), -) - op_ty_results = IR.Type[result_0...,] - operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, shmem, inputs...] 
+ result_inference=false + ) +end + + +function kernel_call(gridx::Value, gridy::Value, gridz::Value, blockx::Value, blocky::Value, blockz::Value, shmem::Value, inputs::Vector{Value}; result_0::Vector{IR.Type}, fn, backend_config=nothing, operand_layouts=nothing, result_layouts=nothing, arg_attrs=nothing, res_attrs=nothing, output_operand_aliases=nothing, xla_side_effect_free=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, shmem, inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - !isnothing(clusterx) && push!(operands, clusterx) - !isnothing(clustery) && push!(operands, clustery) - !isnothing(clusterz) && push!(operands, clusterz) - push!( - attributes, - operandsegmentsizes([ - 1, - 1, - 1, - 1, - 1, - 1, - 1, - (clusterx == nothing) ? 0 : 1, - (clustery == nothing) ? 0 : 1, - (clusterz == nothing) ? 0 : 1, - length(inputs), - ]), - ) - !isnothing(backend_config) && - push!(attributes, namedattribute("backend_config", backend_config)) - !isnothing(operand_layouts) && - push!(attributes, namedattribute("operand_layouts", operand_layouts)) - !isnothing(result_layouts) && - push!(attributes, namedattribute("result_layouts", result_layouts)) + attributes = NamedAttribute[namedattribute("fn", fn), ] + !isnothing(backend_config) && push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && push!(attributes, namedattribute("result_layouts", result_layouts)) !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - !isnothing(output_operand_aliases) && - push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) - !isnothing(xla_side_effect_free) 
&& - push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) - - return create_operation( - "enzymexla.kernel_call", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(output_operand_aliases) && push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + create_operation( + "enzymexla.kernel_call", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function linalg_lu( - input::Value; - output::IR.Type, - pivots::IR.Type, - permutation::IR.Type, - info::IR.Type, - location=Location(), -) - op_ty_results = IR.Type[output, pivots, permutation, info] - operands = Value[input,] + +function linalg_lu(input::Value; output::IR.Type, pivots::IR.Type, permutation::IR.Type, info::IR.Type, location=Location()) + op_ty_results = IR.Type[output, pivots, permutation, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.linalg.lu", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.linalg.lu", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -591,68 +373,51 @@ that case, it returns a !gpu.async.token. 
%token = gpu.memcpy async [%dep] %dst, %src : memref, memref ``` """ -function memcpy( - asyncDependencies::Vector{Value}, - target::Value, - source::Value, - size::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), -) +function memcpy(asyncDependencies::Vector{Value}, target::Value, source::Value, size::Value; asyncToken=nothing::Union{Nothing, IR.Type}, location=Location()) op_ty_results = IR.Type[] - operands = Value[asyncDependencies..., target, source, size] + operands = Value[asyncDependencies..., target, source, size, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(asyncToken) && push!(op_ty_results, asyncToken) - - return create_operation( - "enzymexla.memcpy", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.memcpy", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function memref2pointer(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] - operands = Value[source,] + op_ty_results = IR.Type[result, ] + operands = Value[source, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.memref2pointer", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.memref2pointer", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function noop(blockDims::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[blockDims...,] + operands = Value[blockDims..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.noop", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.noop", 
location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -662,21 +427,17 @@ end This operation is modeled after LAPACK\'s *ORGQR/*UNGQR routines. """ function lapack_orgqr(input::Value, tau::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[input, tau] + op_ty_results = IR.Type[output, ] + operands = Value[input, tau, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.lapack.orgqr", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.orgqr", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -685,69 +446,51 @@ end This operation is modeled after LAPACK\'s *ORMQR routines. """ -function lapack_ormqr( - A::Value, - tau::Value, - C::Value; - output::IR.Type, - side, - transpose=nothing, - location=Location(), -) - op_ty_results = IR.Type[output,] - operands = Value[A, tau, C] +function lapack_ormqr(A::Value, tau::Value, C::Value; output::IR.Type, side, transpose=nothing, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[A, tau, C, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side),] + attributes = NamedAttribute[namedattribute("side", side), ] !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) - - return create_operation( - "enzymexla.lapack.ormqr", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.ormqr", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function pointer2memref(source::Value; result::IR.Type, location=Location()) - op_ty_results = 
IR.Type[result,] - operands = Value[source,] + op_ty_results = IR.Type[result, ] + operands = Value[source, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.pointer2memref", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.pointer2memref", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function polygeist_yield(; location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.polygeist_yield", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.polygeist_yield", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -765,139 +508,86 @@ will be a m x n trapezoidal matrix. This operation is modeled after the mathematical formulation of the QR factorization, and not after LAPACK\'s compact formats. 
""" -function linalg_qr( - input::Value; Q::IR.Type, R::IR.Type, algorithm=nothing, location=Location() -) - op_ty_results = IR.Type[Q, R] - operands = Value[input,] +function linalg_qr(input::Value; Q::IR.Type, R::IR.Type, algorithm=nothing, location=Location()) + op_ty_results = IR.Type[Q, R, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(algorithm) && push!(attributes, namedattribute("algorithm", algorithm)) - - return create_operation( - "enzymexla.linalg.qr", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.linalg.qr", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function ml_relu(input::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) + +function ml_relu(input::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) op_ty_results = IR.Type[] - operands = Value[input,] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.ml.relu", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.ml.relu", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? 
true : false) ) end -function rotate( - operand::Value; - result=nothing::Union{Nothing,IR.Type}, - amount, - dimension, - location=Location(), -) + +function rotate(operand::Value; result=nothing::Union{Nothing, IR.Type}, amount, dimension, location=Location()) op_ty_results = IR.Type[] - operands = Value[operand,] + operands = Value[operand, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("amount", amount), namedattribute("dimension", dimension) - ] + attributes = NamedAttribute[namedattribute("amount", amount), namedattribute("dimension", dimension), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.rotate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.rotate", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? 
true : false) ) end -function linalg_svd( - input::Value; - U::IR.Type, - S::IR.Type, - Vt::IR.Type, - info::IR.Type, - full=nothing, - location=Location(), -) - op_ty_results = IR.Type[U, S, Vt, info] - operands = Value[input,] + +function linalg_svd(input::Value; U::IR.Type, S::IR.Type, Vt::IR.Type, info::IR.Type, full=nothing, location=Location()) + op_ty_results = IR.Type[U, S, Vt, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(full) && push!(attributes, namedattribute("full", full)) - - return create_operation( - "enzymexla.linalg.svd", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.linalg.svd", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function stream2token(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] - operands = Value[source,] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[] - return create_operation( - "enzymexla.stream2token", - location; - operands, - owned_regions, - successors, - attributes, - results=op_ty_results, - result_inference=false, - ) -end - -function subindex(source::Value, index::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] - operands = Value[source, index] +function stream2token(source::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[source, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.subindex", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.stream2token", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -906,105 +596,53 @@ 
end C := alpha*A*B + beta*C, or C := alpha*B*A + beta*C, where alpha and beta are scalars, A is a symmetric matrix\" """ -function lapack_symm( - A::Value, - B::Value, - C::Value, - alpha::Value, - beta::Value; - output::IR.Type, - side, - uplo, - location=Location(), -) - op_ty_results = IR.Type[output,] - operands = Value[A, B, C, alpha, beta] +function lapack_symm(A::Value, B::Value, C::Value, alpha::Value, beta::Value; output::IR.Type, side, uplo, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[A, B, C, alpha, beta, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side), namedattribute("uplo", uplo)] - - return create_operation( - "enzymexla.lapack.symm", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("side", side), namedattribute("uplo", uplo), ] + + create_operation( + "enzymexla.lapack.symm", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function typeAlign(; result::IR.Type, source, location=Location()) - op_ty_results = IR.Type[result,] - operands = Value[] - owned_regions = Region[] - successors = Block[] - attributes = NamedAttribute[namedattribute("source", source),] - - return create_operation( - "enzymexla.typeAlign", - location; - operands, - owned_regions, - successors, - attributes, - results=op_ty_results, - result_inference=false, - ) -end -function wrap( - operand::Value; - result=nothing::Union{Nothing,IR.Type}, - lhs, - rhs, - dimension, - location=Location(), -) +function wrap(operand::Value; result=nothing::Union{Nothing, IR.Type}, lhs, rhs, dimension, location=Location()) op_ty_results = IR.Type[] - operands = Value[operand,] + operands = Value[operand, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("lhs", lhs), - namedattribute("rhs", rhs), - 
namedattribute("dimension", dimension), - ] + attributes = NamedAttribute[namedattribute("lhs", lhs), namedattribute("rhs", rhs), namedattribute("dimension", dimension), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.wrap", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.wrap", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? true : false) ) end -function xla_wrapper( - inputs::Vector{Value}; fn, arg_attrs=nothing, res_attrs=nothing, location=Location() -) + +function xla_wrapper(inputs::Vector{Value}; fn, arg_attrs=nothing, res_attrs=nothing, location=Location()) op_ty_results = IR.Type[] - operands = Value[inputs...,] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - - return create_operation( - "enzymexla.xla_wrapper", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.xla_wrapper", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index d4820c9966..9831f704f4 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -273,6 +273,68 @@ function overloaded_mul!( return C end +function overloaded_mul!( + @nospecialize(C::TracedRArray{T,2} where {T}), + @nospecialize(A::Symmetric), + @nospecialize(B::AbstractMatrix), + α::Number=true, + 
β::Number=true, +) + # Promote to traced arrays + A = call_with_reactant(Reactant.promote_to, TracedRArray, parent(A)) + B = call_with_reactant(Reactant.promote_to, TracedRArray, B) + + # Dimension checks + if size(C) != (size(A, 1), size(B, 2)) + throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))")) + end + + T = Reactant.unwrapped_eltype(C) + tmp = @opcall lapack_symm( + T.(materialize_traced_array(A)), + T.(materialize_traced_array(B)), + T.(materialize_traced_array(C)), + Reactant.promote_to(TracedRNumber{T}, α), + Reactant.promote_to(TracedRNumber{T}, β), + side=:L, + uplo=:U, + ) + + set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird + return C +end + +function overloaded_mul!( + @nospecialize(C::TracedRArray{T,2} where {T}), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::Symmetric), + α::Number=true, + β::Number=true, +) + # Promote to traced arrays + A = call_with_reactant(Reactant.promote_to, TracedRArray, A) + B = call_with_reactant(Reactant.promote_to, TracedRArray, parent(B)) + + # Dimension checks + if size(C) != (size(A, 1), size(B, 2)) + throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))")) + end + + T = Reactant.unwrapped_eltype(C) + tmp = @opcall lapack_symm( + T.(materialize_traced_array(A)), + T.(materialize_traced_array(B)), + T.(materialize_traced_array(C)), + Reactant.promote_to(TracedRNumber{T}, α), + Reactant.promote_to(TracedRNumber{T}, β), + side=:R, + uplo=:U, + ) + + set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird + return C +end + function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = @opcall subtract( diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 5790bfc928..e6bc28913b 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -432,3 
+432,34 @@ end 1e-2 end end + +@testset "Symmetric Multiplication" begin + @testset "F32" begin + A = Symmetric(rand(Float32,(10,10))) + B = rand(Float32,(10,10)) + C = rand(Float32,(10,10)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + C_ra = Reactant.to_rarray(C) + + alpha = rand(Float32) + beta = rand(Float32) + + @test @code_hlo optimize=false A_ra * B_ra * alpha + + end + @testset "F64" begin + A = Symmetric(rand(Float64,(10,10))) + B = rand(Float64,(10,10)) + C = rand(Float64,(10,10)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + C_ra = Reactant.to_rarray(C) + + alpha = rand(Float64) + beta = rand(Float64) + + @test @code_hlo optimize=false A_ra * B_ra * alpha + + end +end \ No newline at end of file From 8e2e42479e2ae9be67657ab54be9d1302c82ee85 Mon Sep 17 00:00:00 2001 From: snonk Date: Wed, 29 Oct 2025 13:29:18 -0500 Subject: [PATCH 2/3] revert --- deps/ReactantExtra/make-bindings.jl | 38 +- src/mlir/Dialects/EnzymeXLA.jl | 974 +++++++++++++++++++--------- 2 files changed, 687 insertions(+), 325 deletions(-) diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index ef14ab82f1..f84309fef1 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -22,26 +22,26 @@ end src_dir = joinpath(dirname(dirname(@__DIR__)), "src") for file in [ - # "Builtin.jl", - # "Arith.jl", - # "Affine.jl", - # "Func.jl", - # "Enzyme.jl", + "Builtin.jl", + "Arith.jl", + "Affine.jl", + "Func.jl", + "Enzyme.jl", "EnzymeXLA.jl", - # "StableHLO.jl", - # "CHLO.jl", - # "VHLO.jl", - # "Llvm.jl", - # "Nvvm.jl", - # "Gpu.jl", - # "Affine.jl", - # "TPU.jl", - # "MosaicGPU.jl", - # "Triton.jl", - # "Shardy.jl", - # "MPI.jl", - # "MemRef.jl", - # "SparseTensor.jl", + "StableHLO.jl", + "CHLO.jl", + "VHLO.jl", + "Llvm.jl", + "Nvvm.jl", + "Gpu.jl", + "Affine.jl", + "TPU.jl", + "MosaicGPU.jl", + "Triton.jl", + "Shardy.jl", + "MPI.jl", + "MemRef.jl", + "SparseTensor.jl", ] 
build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index 239c9442ea..4de6d8a13f 100755 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -1,168 +1,259 @@ module enzymexla using ...IR -import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType +import ...IR: + NamedAttribute, + Value, + Location, + Block, + Region, + Attribute, + create_operation, + context, + IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API - - -function scope(operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location()) - op_ty_results = IR.Type[results..., ] - operands = Value[operands..., ] - owned_regions = Region[region, ] +function scope( + operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location() +) + op_ty_results = IR.Type[results...,] + operands = Value[operands...,] + owned_regions = Region[region,] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.scope", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.scope", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function alternatives(; regions::Vector{Region}, location=Location()) op_ty_results = IR.Type[] operands = Value[] - owned_regions = Region[regions..., ] + owned_regions = Region[regions...,] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.alternatives", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.alternatives", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function barrier(indices::Vector{Value}; 
location=Location()) op_ty_results = IR.Type[] - operands = Value[indices..., ] + operands = Value[indices...,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.barrier", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.barrier", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end +function cacheload( + memref::Value, indices::Vector{Value}; result::IR.Type, location=Location() +) + op_ty_results = IR.Type[result,] + operands = Value[memref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.cacheload", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end function comm_region(; result_0::Vector{IR.Type}, body::Region, location=Location()) - op_ty_results = IR.Type[result_0..., ] + op_ty_results = IR.Type[result_0...,] operands = Value[] - owned_regions = Region[body, ] + owned_regions = Region[body,] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.comm_region", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.comm_region", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function extend(operand::Value; result=nothing::Union{Nothing, IR.Type}, lhs, rhs, dimension, location=Location()) +function extend( + operand::Value; + result=nothing::Union{Nothing,IR.Type}, + lhs, + rhs, + dimension, + location=Location(), +) op_ty_results = IR.Type[] - operands = Value[operand, ] + operands = Value[operand,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("lhs", lhs), 
namedattribute("rhs", rhs), namedattribute("dimension", dimension), ] + attributes = NamedAttribute[ + namedattribute("lhs", lhs), + namedattribute("rhs", rhs), + namedattribute("dimension", dimension), + ] !isnothing(result) && push!(op_ty_results, result) - - create_operation( - "enzymexla.extend", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.extend", + location; + operands, + owned_regions, + successors, + attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false) + result_inference=(length(op_ty_results) == 0 ? true : false), ) end - -function gpu_block(blockIndexX::Value, blockIndexY::Value, blockIndexZ::Value; region::Region, location=Location()) +function gpu_block( + blockIndexX::Value, + blockIndexY::Value, + blockIndexZ::Value; + region::Region, + location=Location(), +) op_ty_results = IR.Type[] - operands = Value[blockIndexX, blockIndexY, blockIndexZ, ] - owned_regions = Region[region, ] + operands = Value[blockIndexX, blockIndexY, blockIndexZ] + owned_regions = Region[region,] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.gpu_block", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.gpu_block", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function gpu_error(; result::IR.Type, region::Region, location=Location()) - op_ty_results = IR.Type[result, ] + op_ty_results = IR.Type[result,] operands = Value[] - owned_regions = Region[region, ] + owned_regions = Region[region,] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.gpu_error", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.gpu_error", + location; + operands, + owned_regions, + 
successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function gpu_kernel_address(; result::IR.Type, fn, location=Location()) - op_ty_results = IR.Type[result, ] + op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - - create_operation( - "enzymexla.gpu_kernel_address", location; - operands, owned_regions, successors, attributes, + attributes = NamedAttribute[namedattribute("fn", fn),] + + return create_operation( + "enzymexla.gpu_kernel_address", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function gpu_occupancy(blockSize::Value, dynamicSMemSize::Value, flags::Value; result::IR.Type, fn, location=Location()) - op_ty_results = IR.Type[result, ] - operands = Value[blockSize, dynamicSMemSize, flags, ] +function gpu_occupancy( + blockSize::Value, + dynamicSMemSize::Value, + flags::Value; + result::IR.Type, + fn, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[blockSize, dynamicSMemSize, flags] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - - create_operation( - "enzymexla.gpu_occupancy", location; - operands, owned_regions, successors, attributes, + attributes = NamedAttribute[namedattribute("fn", fn),] + + return create_operation( + "enzymexla.gpu_occupancy", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function gpu_thread(threadIndexX::Value, threadIndexY::Value, threadIndexZ::Value; region::Region, location=Location()) +function gpu_thread( + threadIndexX::Value, + threadIndexY::Value, + threadIndexZ::Value; + region::Region, + location=Location(), +) op_ty_results = IR.Type[] - operands = Value[threadIndexX, 
threadIndexY, threadIndexZ, ] - owned_regions = Region[region, ] + operands = Value[threadIndexX, threadIndexY, threadIndexZ] + owned_regions = Region[region,] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.gpu_thread", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.gpu_thread", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -170,38 +261,52 @@ end `gpu_wrapper` The optional arguments to this operation are suggestions about what block -dimensions this gpu kernel should have - usually taken from kernel launch -params +dimensions this gpu kernel should have - usually taken f rom kernel + launch params """ -function gpu_wrapper(blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location()) - op_ty_results = IR.Type[result, ] - operands = Value[blockDims..., ] - owned_regions = Region[region, ] +function gpu_wrapper( + blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location() +) + op_ty_results = IR.Type[result,] + operands = Value[blockDims...,] + owned_regions = Region[region,] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.gpu_wrapper", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.gpu_wrapper", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function ml_gelu(input::Value; result=nothing::Union{Nothing, IR.Type}, gelu_approximation, location=Location()) +function ml_gelu( + input::Value; + result=nothing::Union{Nothing,IR.Type}, + gelu_approximation, + location=Location(), +) op_ty_results = IR.Type[] - operands = Value[input, ] + operands = Value[input,] owned_regions = Region[] successors = Block[] - attributes = 
NamedAttribute[namedattribute("gelu_approximation", gelu_approximation), ] + attributes = NamedAttribute[namedattribute("gelu_approximation", gelu_approximation),] !isnothing(result) && push!(op_ty_results, result) - - create_operation( - "enzymexla.ml.gelu", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.ml.gelu", + location; + operands, + owned_regions, + successors, + attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false) + result_inference=(length(op_ty_results) == 0 ? true : false), ) end @@ -210,19 +315,31 @@ end This operation is modeled after LAPACK\'s *GEMQR routines. """ -function lapack_gemqrt(V::Value, T::Value, C::Value; output::IR.Type, side, transpose=nothing, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[V, T, C, ] +function lapack_gemqrt( + V::Value, + T::Value, + C::Value; + output::IR.Type, + side, + transpose=nothing, + location=Location(), +) + op_ty_results = IR.Type[output,] + operands = Value[V, T, C] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side), ] + attributes = NamedAttribute[namedattribute("side", side),] !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) - - create_operation( - "enzymexla.lapack.gemqrt", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.lapack.gemqrt", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -231,23 +348,31 @@ end This operation computes the QR factorization of a matrix using Householder reflections. Mathematically, it decomposes A into the product of an -orthogonal matrix Q and an upper triangular matrix R, such that A = QR. +orthogonal matri x Q and an upper triangular matrix R, + such that A = QR. 
-This operation is modeled after LAPACK\'s *GEQRF routines, which returns the -result in the QR packed format. + This operation is modeled after + LAPACK\'s *GEQRF routines, which returns the result in + the QR packed format. """ -function lapack_geqrf(input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()) - op_ty_results = IR.Type[output, tau, info, ] - operands = Value[input, ] +function lapack_geqrf( + input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location() +) + op_ty_results = IR.Type[output, tau, info] + operands = Value[input,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.lapack.geqrf", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.lapack.geqrf", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -261,97 +386,190 @@ orthogonal matrix Q and an upper triangular matrix R, such that A = QR. This operation is modeled after LAPACK\'s *GEQRT routines, which returns the result in the QR CompactWY format. 
""" -function lapack_geqrt(input::Value; output::IR.Type, T::IR.Type, info::IR.Type, blocksize=nothing, location=Location()) - op_ty_results = IR.Type[output, T, info, ] - operands = Value[input, ] +function lapack_geqrt( + input::Value; + output::IR.Type, + T::IR.Type, + info::IR.Type, + blocksize=nothing, + location=Location(), +) + op_ty_results = IR.Type[output, T, info] + operands = Value[input,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(blocksize) && push!(attributes, namedattribute("blocksize", blocksize)) - - create_operation( - "enzymexla.lapack.geqrt", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.lapack.geqrt", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function get_stream(; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result, ] + op_ty_results = IR.Type[result,] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.get_stream", location; - operands, owned_regions, successors, attributes, - results=op_ty_results, - result_inference=false - ) -end - -function jit_call(inputs::Vector{Value}; result_0::Vector{IR.Type}, fn, backend_config=nothing, operand_layouts=nothing, result_layouts=nothing, arg_attrs=nothing, res_attrs=nothing, output_operand_aliases=nothing, xla_side_effect_free=nothing, location=Location()) - op_ty_results = IR.Type[result_0..., ] - operands = Value[inputs..., ] + return create_operation( + "enzymexla.get_stream", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function jit_call( + inputs::Vector{Value}; + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + arg_attrs=nothing, + 
res_attrs=nothing, + output_operand_aliases=nothing, + xla_side_effect_free=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(backend_config) && push!(attributes, namedattribute("backend_config", backend_config)) - !isnothing(operand_layouts) && push!(attributes, namedattribute("operand_layouts", operand_layouts)) - !isnothing(result_layouts) && push!(attributes, namedattribute("result_layouts", result_layouts)) + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - !isnothing(output_operand_aliases) && push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) - !isnothing(xla_side_effect_free) && push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) - - create_operation( - "enzymexla.jit_call", location; - operands, owned_regions, successors, attributes, + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && + push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + return create_operation( + "enzymexla.jit_call", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false - ) -end - - -function kernel_call(gridx::Value, gridy::Value, gridz::Value, blockx::Value, blocky::Value, 
blockz::Value, shmem::Value, inputs::Vector{Value}; result_0::Vector{IR.Type}, fn, backend_config=nothing, operand_layouts=nothing, result_layouts=nothing, arg_attrs=nothing, res_attrs=nothing, output_operand_aliases=nothing, xla_side_effect_free=nothing, location=Location()) - op_ty_results = IR.Type[result_0..., ] - operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, shmem, inputs..., ] + result_inference=false, + ) +end + +function kernel_call( + gridx::Value, + gridy::Value, + gridz::Value, + blockx::Value, + blocky::Value, + blockz::Value, + shmem::Value, + clusterx=nothing::Union{Nothing,Value}; + clustery=nothing::Union{Nothing,Value}, + clusterz=nothing::Union{Nothing,Value}, + inputs::Vector{Value}, + result_0::Vector{IR.Type}, + fn, + backend_config=nothing, + operand_layouts=nothing, + result_layouts=nothing, + arg_attrs=nothing, + res_attrs=nothing, + output_operand_aliases=nothing, + xla_side_effect_free=nothing, + location=Location(), +) + op_ty_results = IR.Type[result_0...,] + operands = Value[gridx, gridy, gridz, blockx, blocky, blockz, shmem, inputs...] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] - !isnothing(backend_config) && push!(attributes, namedattribute("backend_config", backend_config)) - !isnothing(operand_layouts) && push!(attributes, namedattribute("operand_layouts", operand_layouts)) - !isnothing(result_layouts) && push!(attributes, namedattribute("result_layouts", result_layouts)) + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(clusterx) && push!(operands, clusterx) + !isnothing(clustery) && push!(operands, clustery) + !isnothing(clusterz) && push!(operands, clusterz) + push!( + attributes, + operandsegmentsizes([ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + (clusterx == nothing) ? 0 : 1, + (clustery == nothing) ? 0 : 1, + (clusterz == nothing) ? 
0 : 1, + length(inputs), + ]), + ) + !isnothing(backend_config) && + push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && + push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && + push!(attributes, namedattribute("result_layouts", result_layouts)) !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - !isnothing(output_operand_aliases) && push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) - !isnothing(xla_side_effect_free) && push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) - - create_operation( - "enzymexla.kernel_call", location; - operands, owned_regions, successors, attributes, + !isnothing(output_operand_aliases) && + push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && + push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + return create_operation( + "enzymexla.kernel_call", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function linalg_lu(input::Value; output::IR.Type, pivots::IR.Type, permutation::IR.Type, info::IR.Type, location=Location()) - op_ty_results = IR.Type[output, pivots, permutation, info, ] - operands = Value[input, ] +function linalg_lu( + input::Value; + output::IR.Type, + pivots::IR.Type, + permutation::IR.Type, + info::IR.Type, + location=Location(), +) + op_ty_results = IR.Type[output, pivots, permutation, info] + operands = Value[input,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.linalg.lu", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.linalg.lu", 
+ location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -373,51 +591,68 @@ that case, it returns a !gpu.async.token. %token = gpu.memcpy async [%dep] %dst, %src : memref, memref ``` """ -function memcpy(asyncDependencies::Vector{Value}, target::Value, source::Value, size::Value; asyncToken=nothing::Union{Nothing, IR.Type}, location=Location()) +function memcpy( + asyncDependencies::Vector{Value}, + target::Value, + source::Value, + size::Value; + asyncToken=nothing::Union{Nothing,IR.Type}, + location=Location(), +) op_ty_results = IR.Type[] - operands = Value[asyncDependencies..., target, source, size, ] + operands = Value[asyncDependencies..., target, source, size] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(asyncToken) && push!(op_ty_results, asyncToken) - - create_operation( - "enzymexla.memcpy", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.memcpy", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function memref2pointer(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result, ] - operands = Value[source, ] + op_ty_results = IR.Type[result,] + operands = Value[source,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.memref2pointer", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.memref2pointer", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function noop(blockDims::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[blockDims..., ] + operands = Value[blockDims...,] owned_regions = Region[] 
successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.noop", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.noop", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -427,17 +662,21 @@ end This operation is modeled after LAPACK\'s *ORGQR/*UNGQR routines. """ function lapack_orgqr(input::Value, tau::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[input, tau, ] + op_ty_results = IR.Type[output,] + operands = Value[input, tau] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.lapack.orgqr", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.lapack.orgqr", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -446,51 +685,69 @@ end This operation is modeled after LAPACK\'s *ORMQR routines. 
""" -function lapack_ormqr(A::Value, tau::Value, C::Value; output::IR.Type, side, transpose=nothing, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[A, tau, C, ] +function lapack_ormqr( + A::Value, + tau::Value, + C::Value; + output::IR.Type, + side, + transpose=nothing, + location=Location(), +) + op_ty_results = IR.Type[output,] + operands = Value[A, tau, C] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side), ] + attributes = NamedAttribute[namedattribute("side", side),] !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) - - create_operation( - "enzymexla.lapack.ormqr", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.lapack.ormqr", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function pointer2memref(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result, ] - operands = Value[source, ] + op_ty_results = IR.Type[result,] + operands = Value[source,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.pointer2memref", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.pointer2memref", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function polygeist_yield(; location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.polygeist_yield", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.polygeist_yield", + location; + operands, + owned_regions, + successors, + attributes, 
results=op_ty_results, - result_inference=false + result_inference=false, ) end @@ -508,86 +765,139 @@ will be a m x n trapezoidal matrix. This operation is modeled after the mathematical formulation of the QR factorization, and not after LAPACK\'s compact formats. """ -function linalg_qr(input::Value; Q::IR.Type, R::IR.Type, algorithm=nothing, location=Location()) - op_ty_results = IR.Type[Q, R, ] - operands = Value[input, ] +function linalg_qr( + input::Value; Q::IR.Type, R::IR.Type, algorithm=nothing, location=Location() +) + op_ty_results = IR.Type[Q, R] + operands = Value[input,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(algorithm) && push!(attributes, namedattribute("algorithm", algorithm)) - - create_operation( - "enzymexla.linalg.qr", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.linalg.qr", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - -function ml_relu(input::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) +function ml_relu(input::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) op_ty_results = IR.Type[] - operands = Value[input, ] + operands = Value[input,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(result) && push!(op_ty_results, result) - - create_operation( - "enzymexla.ml.relu", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.ml.relu", + location; + operands, + owned_regions, + successors, + attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false) + result_inference=(length(op_ty_results) == 0 ? 
true : false), ) end - -function rotate(operand::Value; result=nothing::Union{Nothing, IR.Type}, amount, dimension, location=Location()) +function rotate( + operand::Value; + result=nothing::Union{Nothing,IR.Type}, + amount, + dimension, + location=Location(), +) op_ty_results = IR.Type[] - operands = Value[operand, ] + operands = Value[operand,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("amount", amount), namedattribute("dimension", dimension), ] + attributes = NamedAttribute[ + namedattribute("amount", amount), namedattribute("dimension", dimension) + ] !isnothing(result) && push!(op_ty_results, result) - - create_operation( - "enzymexla.rotate", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.rotate", + location; + operands, + owned_regions, + successors, + attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false) + result_inference=(length(op_ty_results) == 0 ? 
true : false), ) end - -function linalg_svd(input::Value; U::IR.Type, S::IR.Type, Vt::IR.Type, info::IR.Type, full=nothing, location=Location()) - op_ty_results = IR.Type[U, S, Vt, info, ] - operands = Value[input, ] +function linalg_svd( + input::Value; + U::IR.Type, + S::IR.Type, + Vt::IR.Type, + info::IR.Type, + full=nothing, + location=Location(), +) + op_ty_results = IR.Type[U, S, Vt, info] + operands = Value[input,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(full) && push!(attributes, namedattribute("full", full)) - - create_operation( - "enzymexla.linalg.svd", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.linalg.svd", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end - function stream2token(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result, ] - operands = Value[source, ] + op_ty_results = IR.Type[result,] + operands = Value[source,] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - create_operation( - "enzymexla.stream2token", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.stream2token", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, + ) +end + +function subindex(source::Value, index::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[source, index] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.subindex", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, ) end @@ -596,53 +906,105 @@ end C := alpha*A*B + beta*C, or C := alpha*B*A + beta*C, where alpha and 
beta are scalars, A is a symmetric matrix\" """ -function lapack_symm(A::Value, B::Value, C::Value, alpha::Value, beta::Value; output::IR.Type, side, uplo, location=Location()) - op_ty_results = IR.Type[output, ] - operands = Value[A, B, C, alpha, beta, ] +function lapack_symm( + A::Value, + B::Value, + C::Value, + alpha::Value, + beta::Value; + output::IR.Type, + side, + uplo, + location=Location(), +) + op_ty_results = IR.Type[output,] + operands = Value[A, B, C, alpha, beta] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side), namedattribute("uplo", uplo), ] - - create_operation( - "enzymexla.lapack.symm", location; - operands, owned_regions, successors, attributes, + attributes = NamedAttribute[namedattribute("side", side), namedattribute("uplo", uplo)] + + return create_operation( + "enzymexla.lapack.symm", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end +function typeAlign(; result::IR.Type, source, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("source", source),] + + return create_operation( + "enzymexla.typeAlign", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end -function wrap(operand::Value; result=nothing::Union{Nothing, IR.Type}, lhs, rhs, dimension, location=Location()) +function wrap( + operand::Value; + result=nothing::Union{Nothing,IR.Type}, + lhs, + rhs, + dimension, + location=Location(), +) op_ty_results = IR.Type[] - operands = Value[operand, ] + operands = Value[operand,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("lhs", lhs), namedattribute("rhs", rhs), namedattribute("dimension", dimension), ] + attributes = NamedAttribute[ + 
namedattribute("lhs", lhs), + namedattribute("rhs", rhs), + namedattribute("dimension", dimension), + ] !isnothing(result) && push!(op_ty_results, result) - - create_operation( - "enzymexla.wrap", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.wrap", + location; + operands, + owned_regions, + successors, + attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false) + result_inference=(length(op_ty_results) == 0 ? true : false), ) end - -function xla_wrapper(inputs::Vector{Value}; fn, arg_attrs=nothing, res_attrs=nothing, location=Location()) +function xla_wrapper( + inputs::Vector{Value}; fn, arg_attrs=nothing, res_attrs=nothing, location=Location() +) op_ty_results = IR.Type[] - operands = Value[inputs..., ] + operands = Value[inputs...,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn), ] + attributes = NamedAttribute[namedattribute("fn", fn),] !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - - create_operation( - "enzymexla.xla_wrapper", location; - operands, owned_regions, successors, attributes, + + return create_operation( + "enzymexla.xla_wrapper", + location; + operands, + owned_regions, + successors, + attributes, results=op_ty_results, - result_inference=false + result_inference=false, ) end From 01f9d3f1fa5f21662e81628ff6e229e57712ce81 Mon Sep 17 00:00:00 2001 From: snonk Date: Wed, 29 Oct 2025 13:58:06 -0500 Subject: [PATCH 3/3] fix return type --- src/Ops.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Ops.jl b/src/Ops.jl index cbfbd9decd..16e7aeaf65 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -961,7 +961,7 @@ end location, ), ) - return TracedRArray{resT,length(ressize)}((), res, ressize) + return 
TracedRArray{T,length(ressize)}((), res, ressize) end Base.@nospecializeinfer @noinline function dot_general(