Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ext/ReactantKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ReactantKernelAbstractionsExt

using Reactant: Reactant
using ReactantCore: ReactantCore

using Adapt: Adapt
using KernelAbstractions: KernelAbstractions
Expand Down Expand Up @@ -101,6 +102,14 @@ function tokw(ndrange, workgroupsize, obj, args...)
end

function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing)
# If we're already inside a compilation/tracing context, or if any arguments are traced,
# we should trace through this kernel call instead of trying to compile it again.
if Reactant.within_compile() || any(ReactantCore.is_traced, args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems extraneous can this be done without?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I don't put it I get the error

ERROR: "Cannot trace existing trace type"
Stacktrace:
  [1] make_tracer(seen::Reactant.OrderedIdDict{…}, prev::Reactant.TracedRArray{…}, path::Any, mode::Reactant.TraceMode; toscalar::Bool, tobatch::Nothing, sharding::Any, runtime::Any, kwargs::@Kwargs{})
    @ Reactant ~/.julia/dev/Reactant/src/Tracing.jl:1298
  [2] prepare_mlir_fn_args(args::Tuple{…}, name::String, concretein::Bool, toscalar::Bool, argprefix::Symbol, runtime::Val{…}, optimize_then_pad::Bool, do_transpose::Bool, input_shardings::Nothing, verify_arg_names::Nothing)
    @ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:450
  [3] make_mlir_fn(f::typeof(ReactantKernelAbstractionsExt.tokw), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:321
  [4] make_mlir_fn
    @ ~/.julia/dev/Reactant/src/TracedUtils.jl:275 [inlined]
  [5] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(ReactantKernelAbstractionsExt.tokw), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1608
  [6] compile_mlir!
    @ ~/.julia/dev/Reactant/src/Compiler.jl:1572 [inlined]
  [7] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3500
  [8] compile_xla
    @ ~/.julia/dev/Reactant/src/Compiler.jl:3472 [inlined]
  [9] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3576
 [10] macro expansion
    @ ~/.julia/dev/Reactant/src/Compiler.jl:2649 [inlined]
 [11] (::KernelAbstractions.Kernel{…})(::Reactant.TracedRArray{…}, ::Vararg{…}; ndrange::Int64, workgroupsize::Nothing)
    @ ReactantKernelAbstractionsExt ~/.julia/dev/Reactant/ext/ReactantKernelAbstractionsExt.jl:116
 [12] Kernel
    @ ~/.julia/dev/Reactant/ext/ReactantKernelAbstractionsExt.jl:104 [inlined]
 [13] spmv!
    @ ~/.julia/dev/Reactant/testsparse/script.jl:48 [inlined]
 [14] mul!(y::Reactant.TracedRArray{…}, A::GenericSparseMatrixCSR{…}, x::Reactant.TracedRArray{…}, α::Bool, β::Bool)
    @ Main ~/.julia/dev/Reactant/testsparse/script.jl:66
 [15] #mul!
    @ ~/.julia/dev/Reactant/src/Overlay.jl:136 [inlined]
 [16] (::Nothing)(none::typeof(mul!), none::Reactant.TracedRArray{…}, none::GenericSparseMatrixCSR{…}, none::Reactant.TracedRArray{…}, none::Bool, none::Bool)
    @ Reactant ./<missing>:0
 [17] call_with_reactant(::typeof(mul!), ::Reactant.TracedRArray{…}, ::GenericSparseMatrixCSR{…}, ::Reactant.TracedRArray{…}, ::Bool, ::Bool)
    @ Reactant ~/.julia/dev/Reactant/src/utils.jl:519
 [18] #mul!
    @ ~/.julia/dev/Reactant/src/Overlay.jl:143 [inlined]
 [19] (::Nothing)(none::typeof(mul!), none::Reactant.TracedRArray{…}, none::GenericSparseMatrixCSR{…}, none::Reactant.TracedRArray{…})
    @ Reactant ./<missing>:0
 [20] #mul!
    @ ~/.julia/dev/Reactant/src/Overlay.jl:143 [inlined]
 [21] call_with_reactant(::typeof(mul!), ::Reactant.TracedRArray{…}, ::GenericSparseMatrixCSR{…}, ::Reactant.TracedRArray{…})
    @ Reactant ~/.julia/dev/Reactant/src/utils.jl:0
 [22] make_mlir_fn(f::typeof(mul!), args::Tuple{…}, kwargs::@NamedTuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, args_in_result::Symbol, construct_function_without_args::Bool, do_transpose::Bool, input_shardings::Nothing, output_shardings::Nothing, runtime::Val{…}, verify_arg_names::Nothing, argprefix::Symbol, resprefix::Symbol, resargprefix::Symbol, num_replicas::Int64, optimize_then_pad::Bool)
    @ Reactant.TracedUtils ~/.julia/dev/Reactant/src/TracedUtils.jl:345
 [23] make_mlir_fn
    @ ~/.julia/dev/Reactant/src/TracedUtils.jl:275 [inlined]
 [24] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::typeof(mul!), args::Tuple{…}, compile_options::CompileOptions, callcache::Dict{…}, sdycache::Dict{…}; fn_kwargs::@NamedTuple{}, backend::String, runtime::Val{…}, legalize_stablehlo_to_mhlo::Bool, client::Reactant.XLA.PJRT.Client, kwargs::@Kwargs{})
    @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:1608
 [25] compile_mlir!
    @ ~/.julia/dev/Reactant/src/Compiler.jl:1572 [inlined]
 [26] compile_xla(f::Function, args::Tuple{…}; before_xla_optimizations::Bool, client::Nothing, serializable::Bool, kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3500
 [27] compile_xla
    @ ~/.julia/dev/Reactant/src/Compiler.jl:3472 [inlined]
 [28] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
    @ Reactant.Compiler ~/.julia/dev/Reactant/src/Compiler.jl:3576
 [29] top-level scope
    @ ~/.julia/dev/Reactant/src/Compiler.jl:2649
Some type information was truncated. Use `show(err)` to see complete types.

Another alternative it to remove this check and specify the method directly as

for (cT, aT, bT) in (
    (:AbstractVector, :AnyDenseMatrix, :AbstractVector),
    (:AbstractMatrix, :AnyDenseMatrix, :AbstractVecOrMat),
)
    @eval begin
        @reactant_overlay @noinline function LinearAlgebra.mul!(
            C::$cT, A::$aT, B::$bT, α::Number, β::Number
        )

where AnyDenseMatrix is something like

const AnyDenseMatrix = Union{DenseMatrix, Transpose{Any, <:DenseMatrix}, Symmetric{Any, <:DenseMatrix}, UpperTriangular{Any, <:DenseMatrix}} # And all the other possible wrappers

This basically keeps the orgiginal code unchanged.

I don't know which case you do prefer.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does sound a bit dangerous to assume that we know in advance every possible wrapper of a dense matrix? Not all of them are in Base or LinearAlgebra

return Reactant.call_with_reactant(
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
)
end

if Reactant.precompiling()
Reactant.@code_hlo optimize = false tokw(ndrange, workgroupsize, obj, args...)
else
Expand Down
2 changes: 2 additions & 0 deletions ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ using SparseArrays:
include("Errors.jl")
include("ReadOnly.jl")

Reactant.use_overlayed_version(::AbstractSparseArray) = false

end
9 changes: 5 additions & 4 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,18 +121,19 @@ for (cT, aT, bT) in (
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
A, B = aos_to_soa(A), aos_to_soa(B)
A2, B2 = aos_to_soa(A), aos_to_soa(B)
C2 = aos_to_soa(C)
if use_overlayed_version((C2, A, B))
TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β)
# A2 can also be a SparseMatrix, which should be handled by its own methods
if use_overlayed_version(A2) && use_overlayed_version((C2, A2, B2))
TracedLinearAlgebra.overloaded_mul!(C2, A2, B2, α, β)
if C2 !== C
C .= C2
end
else
# Inference barrier is required when calling function recursively within
# overload. This is required since otherwise type inference will think this
# is a recursive edge rather than a call to the base method
Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β)
Base.inferencebarrier(LinearAlgebra.mul!)(C2, A2, B2, α, β)
end
return C
end
Expand Down