Skip to content

Commit 6677300

Browse files
albertomercuriowsmoses
authored andcommitted
Tracing check within KernelAbstractions
1 parent c23f255 commit 6677300

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

ext/ReactantKernelAbstractionsExt.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ReactantKernelAbstractionsExt
22

33
using Reactant: Reactant
4+
using ReactantCore: ReactantCore
45

56
using Adapt: Adapt
67
using KernelAbstractions: KernelAbstractions
@@ -101,6 +102,14 @@ function tokw(ndrange, workgroupsize, obj, args...)
101102
end
102103

103104
function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing)
105+
# If we're already inside a compilation/tracing context, or if any arguments are traced,
106+
# we should trace through this kernel call instead of trying to compile it again.
107+
if Reactant.within_compile() || any(ReactantCore.is_traced, args)
108+
return Reactant.call_with_reactant(
109+
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
110+
)
111+
end
112+
104113
if Reactant.precompiling()
105114
Reactant.@code_hlo optimize = false tokw(ndrange, workgroupsize, obj, args...)
106115
else

src/Overlay.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ for (cT, aT, bT) in (
133133
# Inference barrier is required when calling function recursively within
134134
# overload. This is required since otherwise type inference will think this
135135
# is a recursive edge rather than a call to the base method
136-
Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β)
136+
Base.inferencebarrier(LinearAlgebra.mul!)(C2, A2, B2, α, β)
137137
end
138138
return C
139139
end

0 commit comments

Comments
 (0)