Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e"
Expand All @@ -35,7 +36,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand Down
5 changes: 2 additions & 3 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ end
f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0
) where {F,tt}
return CUDA.launch_configuration(
Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun;
call_with_native(CUDA.cufunction, f.f, Tuple{tt.parameters[2:end]...}).fun;
shmem,
max_threads,
)
Expand Down Expand Up @@ -1465,7 +1465,7 @@ end
@static if !Sys.isapple()
@setup_workload begin
Reactant.initialize_dialect()

init_jit()
if Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT"
client = Reactant.XLA.PJRT.CPUClient(; checkcount=false)
elseif Reactant.XLA.REACTANT_XLA_RUNTIME == "IFRT"
Expand Down Expand Up @@ -1504,7 +1504,6 @@ end
Reactant.XLA.free_client(client)
client.client = C_NULL
Reactant.deinitialize_dialect()
Reactant.clear_oc_cache()
end
end

Expand Down
10 changes: 5 additions & 5 deletions ext/ReactantNNlibExt/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
if any(Reactant.use_overlayed_version, (y, x, w))
overloaded_conv!(y, x, w, cdims; kwargs...)
else
Base.inferencebarrier(NNlib.conv!)(y, x, w, cdims; kwargs...)
call_with_native(NNlib.conv!, y, x, w, cdims; kwargs...)
end
end

@reactant_overlay @noinline function NNlib.maxpool!(y, x, pdims::NNlib.PoolDims; kwargs...)
if any(Reactant.use_overlayed_version, (y, x))
overloaded_maxpool!(y, x, pdims; kwargs...)
else
Base.inferencebarrier(NNlib.maxpool!)(y, x, pdims; kwargs...)
call_with_native(NNlib.maxpool!, y, x, pdims; kwargs...)
end
end

@reactant_overlay @noinline function NNlib.meanpool!(y, x, pdims::NNlib.PoolDims; kwargs...)
if any(Reactant.use_overlayed_version, (y, x))
overloaded_meanpool!(y, x, pdims; kwargs...)
else
Base.inferencebarrier(NNlib.meanpool!)(y, x, pdims; kwargs...)
call_with_native(NNlib.meanpool!, y, x, pdims; kwargs...)
end
end

Expand All @@ -28,7 +28,7 @@ end
if any(Reactant.use_overlayed_version, (dw, x, dy))
overloaded_∇conv_filter!(dw, x, dy, cdims; kwargs...)
else
Base.inferencebarrier(NNlib.∇conv_filter!)(dw, x, dy, cdims; kwargs...)
call_with_native(NNlib.∇conv_filter!, dw, x, dy, cdims; kwargs...)
end
end

Expand All @@ -38,6 +38,6 @@ end
if any(Reactant.use_overlayed_version, (dx, dy, w))
overloaded_∇conv_data!(dx, dy, w, cdims; kwargs...)
else
Base.inferencebarrier(NNlib.∇conv_data!)(dx, dy, w, cdims; kwargs...)
call_with_native(NNlib.∇conv_data!, dx, dy, w, cdims; kwargs...)
end
end
2 changes: 1 addition & 1 deletion ext/ReactantZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated
and hence reliance on this behavior is strongly discouraged."
return Enzyme.gradient(Reverse, Const(f), args...)
else
return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...)
return call_with_native(Zygote.gradient, CallWithReactant(f), args...)
end
end

Expand Down
110 changes: 0 additions & 110 deletions src/Interpreter.jl

This file was deleted.

Loading
Loading