From 91769d7b494ed8077019db96e806005e9e0136e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Thu, 16 Oct 2025 15:16:28 +0200 Subject: [PATCH 1/3] GPUCompiler + custom JIT --- Project.toml | 2 +- ext/ReactantCUDAExt.jl | 5 +- ext/ReactantNNlibExt/Overlay.jl | 10 +- ext/ReactantZygoteExt.jl | 2 +- src/Interpreter.jl | 110 ---- src/JIT.jl | 632 +++++++++++++++++++++++ src/Overlay.jl | 36 +- src/Precompile.jl | 49 +- src/Reactant.jl | 4 +- src/utils.jl | 872 -------------------------------- 10 files changed, 662 insertions(+), 1060 deletions(-) delete mode 100644 src/Interpreter.jl create mode 100644 src/JIT.jl diff --git a/Project.toml b/Project.toml index fbde73622a..800ac8a551 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" HTTP = "cd3eb016-35fb-5094-929b-558a96fad6f3" LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" LLVMOpenMP_jll = "1d63c593-3942-5779-bab2-d838dc0a180e" @@ -35,7 +36,6 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" -GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index bd9b55edde..2fa3bb6f3e 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -617,7 +617,7 @@ end f::LLVMFunc{F,tt}; shmem::Union{Integer,Base.Callable}=0, max_threads::Integer=0 ) where {F,tt} return CUDA.launch_configuration( - Base.inferencebarrier(CUDA.cufunction)(f.f, Tuple{tt.parameters[2:end]...}).fun; + call_with_native(CUDA.cufunction, f.f, Tuple{tt.parameters[2:end]...}).fun; shmem, max_threads, ) @@ -1465,7 +1465,7 @@ end @static if !Sys.isapple() @setup_workload begin Reactant.initialize_dialect() - + init_jit() if Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT" client = Reactant.XLA.PJRT.CPUClient(; checkcount=false) elseif Reactant.XLA.REACTANT_XLA_RUNTIME == "IFRT" @@ -1504,7 +1504,6 @@ end Reactant.XLA.free_client(client) client.client = C_NULL Reactant.deinitialize_dialect() - Reactant.clear_oc_cache() end end diff --git a/ext/ReactantNNlibExt/Overlay.jl b/ext/ReactantNNlibExt/Overlay.jl index 24c5ddb372..e899474efe 100644 --- a/ext/ReactantNNlibExt/Overlay.jl +++ b/ext/ReactantNNlibExt/Overlay.jl @@ -2,7 +2,7 @@ if any(Reactant.use_overlayed_version, (y, x, w)) overloaded_conv!(y, x, w, cdims; kwargs...) else - Base.inferencebarrier(NNlib.conv!)(y, x, w, cdims; kwargs...) + call_with_native(NNlib.conv!, y, x, w, cdims; kwargs...) end end @@ -10,7 +10,7 @@ end if any(Reactant.use_overlayed_version, (y, x)) overloaded_maxpool!(y, x, pdims; kwargs...) else - Base.inferencebarrier(NNlib.maxpool!)(y, x, pdims; kwargs...) + call_with_native(NNlib.maxpool!, y, x, pdims; kwargs...) end end @@ -18,7 +18,7 @@ end if any(Reactant.use_overlayed_version, (y, x)) overloaded_meanpool!(y, x, pdims; kwargs...) else - Base.inferencebarrier(NNlib.meanpool!)(y, x, pdims; kwargs...) + call_with_native(NNlib.meanpool!, y, x, pdims; kwargs...) end end @@ -28,7 +28,7 @@ end if any(Reactant.use_overlayed_version, (dw, x, dy)) overloaded_∇conv_filter!(dw, x, dy, cdims; kwargs...) else - Base.inferencebarrier(NNlib.∇conv_filter!)(dw, x, dy, cdims; kwargs...) + call_with_native(NNlib.∇conv_filter!, dw, x, dy, cdims; kwargs...) end end @@ -38,6 +38,6 @@ end if any(Reactant.use_overlayed_version, (dx, dy, w)) overloaded_∇conv_data!(dx, dy, w, cdims; kwargs...) else - Base.inferencebarrier(NNlib.∇conv_data!)(dx, dy, w, cdims; kwargs...) + call_with_native(NNlib.∇conv_data!, dx, dy, w, cdims; kwargs...) end end diff --git a/ext/ReactantZygoteExt.jl b/ext/ReactantZygoteExt.jl index f50229ea17..2ce5f29e77 100644 --- a/ext/ReactantZygoteExt.jl +++ b/ext/ReactantZygoteExt.jl @@ -22,7 +22,7 @@ using Enzyme: Enzyme, Reverse, Active, Const, Duplicated and hence reliance on this behavior is strongly discouraged." return Enzyme.gradient(Reverse, Const(f), args...) else - return Base.inferencebarrier(Zygote.gradient)(CallWithReactant(f), args...) + return call_with_native(Zygote.gradient, CallWithReactant(f), args...) end end diff --git a/src/Interpreter.jl b/src/Interpreter.jl deleted file mode 100644 index 6ffcec0ab3..0000000000 --- a/src/Interpreter.jl +++ /dev/null @@ -1,110 +0,0 @@ -# Taken from https://github.com/JuliaLang/julia/pull/52964/files#diff-936d33e524bcd097015043bd6410824119be5c210d43185c4d19634eb4912708 -# Other references: -# - https://github.com/JuliaLang/julia/blob/0fd1f04dc7d4b905b0172b7130e9b1beab9bc4c9/test/compiler/AbstractInterpreter.jl#L228-L234 -# - https://github.com/JuliaLang/julia/blob/v1.10.4/test/compiler/newinterp.jl#L9 - -const CC = Core.Compiler -using Enzyme - -import Core.Compiler: - AbstractInterpreter, - abstract_call, - abstract_call_known, - ArgInfo, - StmtInfo, - AbsIntState, - get_max_methods, - CallMeta, - Effects, - NoCallInfo, - MethodResultPure - -Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) - -function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def) - return Base.Experimental.var"@overlay"( - __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def - ) -end - -function set_reactant_abi( - interp, - @nospecialize(f), - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int=get_max_methods(interp, f, sv), -) - (; fargs, argtypes) = arginfo - - if f === ReactantCore.within_compile - if length(argtypes) != 1 - @static if VERSION < v"1.11.0-" - return CallMeta(Union{}, Effects(), NoCallInfo()) - else - return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) - end - end - @static if VERSION < v"1.11.0-" - return CallMeta( - Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure() - ) - else - return CallMeta( - Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure() - ) - end - end - - # Improve inference by considering call_with_reactant as having the same results as - # the original call - if f === call_with_reactant - arginfo2 = ArgInfo(fargs isa Nothing ? nothing : fargs[2:end], argtypes[2:end]) - return abstract_call(interp, arginfo2::ArgInfo, si, sv, max_methods) - end - - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - f::Any, - arginfo::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) -end - -@static if Enzyme.GPUCompiler.HAS_INTEGRATED_CACHE - struct ReactantCacheToken end - - function ReactantInterpreter(; world::UInt=Base.get_world_counter()) - return Enzyme.Compiler.Interpreter.EnzymeInterpreter( - ReactantCacheToken(), - REACTANT_METHOD_TABLE, - world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# - set_reactant_abi, - ) - end -else - const REACTANT_CACHE = Enzyme.GPUCompiler.CodeCache() - - function ReactantInterpreter(; - world::UInt=Base.get_world_counter(), code_cache=REACTANT_CACHE - ) - return Enzyme.Compiler.Interpreter.EnzymeInterpreter( - REACTANT_CACHE, - REACTANT_METHOD_TABLE, - world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# - false, #=within_autodiff_rewrite=# - set_reactant_abi, - ) - end -end diff --git a/src/JIT.jl b/src/JIT.jl new file mode 100644 index 0000000000..d86d54c727 --- /dev/null +++ b/src/JIT.jl @@ -0,0 +1,632 @@ +using GPUCompiler +CC = Core.Compiler + +#leak each argument to a global variable +macro lk(args...) + quote + $([:( + let val = $(esc(p)) + global $(esc(p)) = val + end + ) for p in args]...) + end +end + +Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE) + +function var"@reactant_overlay"(__source__::LineNumberNode, __module__::Module, def) + return Base.Experimental.var"@overlay"( + __source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def + ) +end + +function call_with_reactant() end + +@noinline call_with_native(@nospecialize(f), @nospecialize(args...)) = + Base.inferencebarrier(f)(args...) + +const __skip_rewrite_func_set = Set([ + typeof(call_with_reactant), + typeof(call_with_native), + typeof(task_local_storage), + typeof(getproperty), + typeof(invokelatest), +]) +const __skip_rewrite_func_set_lock = ReentrantLock() + +""" + @skip_rewrite_func f + +Mark function `f` so that Reactant's IR rewrite mechanism will skip it. +This can improve compilation time if it's safe to assume that no call inside `f` +will need a `@reactant_overlay` method. + +!!! info + Note that this marks the whole function, not a specific method with a type + signature. + +!!! warning + The macro call should be inside the `__init__` function. If you want to + mark it for precompilation, you must add the macro call in the global scope + too. + +See also: [`@skip_rewrite_type`](@ref) +""" +macro skip_rewrite_func(fname) + quote + @lock $(Reactant.__skip_rewrite_func_set_lock) push!( + $(Reactant.__skip_rewrite_func_set), typeof($(esc(fname))) + ) + end +end + +const __skip_files = Set([Symbol("sysimg.jl"), Symbol("boot.jl")]) + +struct CompilerParams <: AbstractCompilerParams + function CompilerParams() + return new() + end +end + +@kwdef struct MetaData end + +@kwdef struct DebugData + enable_log::Bool = true + enable_runtime_log::Bool = true + rewrite_call::Set = Set() + non_rewrite_call::Set = Set() +end + +struct ReactantToken end + +@kwdef struct ReactantInterpreter <: CC.AbstractInterpreter + token::ReactantToken = ReactantToken() + # Cache of inference results for this particular interpreter + local_cache::Vector{CC.InferenceResult} = CC.InferenceResult[] + # The world age we're working inside of + world::UInt = Base.get_world_counter() + + # Parameters for inference and optimization + inf_params::CC.InferenceParams = CC.InferenceParams() + opt_params::CC.OptimizationParams = CC.OptimizationParams() + + meta_data::Ref{MetaData} = Ref(MetaData()) + debug_data::Ref{DebugData} = Ref(DebugData()) +end + +log(interp::ReactantInterpreter)::Bool = interp.debug_data[].enable_log +runtime_log(interp::ReactantInterpreter)::Bool = interp.debug_data[].enable_runtime_log +reset_debug_data(interp::ReactantInterpreter) = interp.debug_data[] = DebugData(); + +NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams} +GPUCompiler.can_throw(@nospecialize(job::NativeCompilerJob)) = true +function GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob)) + return CC.method_table(GPUCompiler.get_interpreter(job)) +end + +current_interpreter = Ref{Union{Nothing,ReactantInterpreter}}(nothing) + +function GPUCompiler.get_interpreter(@nospecialize(job::NativeCompilerJob)) + isnothing(current_interpreter[]) && + (return current_interpreter[] = ReactantInterpreter(; world=job.world)) + + if job.world == current_interpreter[].world + current_interpreter[] + else + (; meta_data, debug_data) = current_interpreter[] + current_interpreter[] = ReactantInterpreter(; + world=job.world, meta_data, debug_data + ) + end +end + +@noinline barrier(@nospecialize(x), @nospecialize(T::Type = Any)) = + Core.Compiler.inferencebarrier(x)::T + +CC.InferenceParams(@nospecialize(interp::ReactantInterpreter)) = interp.inf_params +CC.OptimizationParams(@nospecialize(interp::ReactantInterpreter)) = interp.opt_params +CC.get_inference_world(@nospecialize(interp::ReactantInterpreter)) = interp.world +CC.get_inference_cache(@nospecialize(interp::ReactantInterpreter)) = interp.local_cache +CC.cache_owner(@nospecialize(interp::ReactantInterpreter)) = interp.token +function CC.method_table(@nospecialize(interp::ReactantInterpreter)) + return CC.OverlayMethodTable(CC.get_inference_world(interp), REACTANT_METHOD_TABLE) +end + +function has_ancestor(query::Module, target::Module) + query == target && return true + while true + next = parentmodule(query) + next == target && return true + next == query && return false + query = next + end +end +is_base_or_core(t::TypeVar) = begin + println("TypeVar ", t) + return false +end +is_base_or_core(t::Core.TypeofVararg) = is_base_or_core(t.T) +is_base_or_core(m::Module) = has_ancestor(m, Core) || has_ancestor(m, Base) +is_base_or_core(@nospecialize(u::Union)) = begin + u == Union{} && return true + is_base_or_core(u.a) && is_base_or_core(u.b) +end +is_base_or_core(u::UnionAll) = is_base_or_core(Base.unwrap_unionall(u)) +is_base_or_core(@nospecialize(ty::Type)) = is_base_or_core(parentmodule(ty)) + +function skip_rewrite(mi::Core.MethodInstance)::Bool + mod = mi.def.module + mi.def.file in __skip_files && return true + @lk mi + ft = Base.unwrap_unionall(mi.specTypes).parameters[1] + ft in __skip_rewrite_func_set && return true + + ( + has_ancestor(mod, Reactant.Ops) || + has_ancestor(mod, Reactant.TracedUtils) || + has_ancestor(mod, Reactant.MLIR) + ) && return true + + if is_base_or_core(mod) + modules = is_base_or_core.(Base.unwrap_unionall(mi.specTypes).parameters[2:end]) + all(modules) && return true + end + return false +end + +disable_call_with_reactant = false +vv = [] +vb = [] +@inline function typeinf_local(interp::CC.AbstractInterpreter, frame::CC.InferenceState) + @invoke CC.typeinf_local(interp::CC.AbstractInterpreter, frame) +end + +function CC.typeinf_local(interp::ReactantInterpreter, frame::CC.InferenceState) + mi = frame.linfo + global disable_call_with_reactant + disable_cwr = disable_call_with_reactant ? false : skip_rewrite(mi) + disable_cwr && (disable_call_with_reactant = true) + disable_call_with_reactant || push!(vb, (mi, CC.copy(frame.src))) + tl = typeinf_local(interp, frame) + disable_call_with_reactant || push!(vv, (mi, CC.copy(frame.src))) + disable_cwr && (disable_call_with_reactant = false) + return tl +end + +lead_to_dynamic_call(@nospecialize(ty)) = begin + isconcretetype(ty) && return false + ty == Union{} && return false + Base.isvarargtype(ty) && return true + (ty <: Type || ty <: Tuple) && return false + return true +end + +# Rewrite type unstable calls to recurse into call_with_reactant to ensure +# they continue to use our interpreter. +function need_rewrite_call(interp, @nospecialize(fn), @nospecialize(args)) + #UnionAll constructor cannot get a singleton type, and are not handled by the call_with_reactant macro: degradate type inference + isnothing(fn) && return false + #ignore constructor + fn isa Type && return false + + ft = typeof(fn) + (ft <: Core.IntrinsicFunction || ft <: Core.Builtin) && return false + ft in __skip_rewrite_func_set && return false + #Base.isstructtype(ft) && return false + if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :module) + mod = ft.name.module + # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions + if has_ancestor(mod, Reactant.Ops) || + has_ancestor(mod, Reactant.TracedUtils) || + has_ancestor(mod, Reactant.MLIR) || + has_ancestor(mod, Core.Compiler) + return false + end + end + #ft isa Type && any(t -> ft <: t, __skip_rewrite_type_constructor_list) && return false + #ft in __skip_rewrite_func_set && return false + + #ft<: typeof(Core.kwcall) && return true + tt = Tuple{ft,args...} + match = CC._findsup(tt, REACTANT_METHOD_TABLE, CC.get_inference_world(interp))[1] + !isnothing(match) && return true + match = CC._findsup(tt, nothing, CC.get_inference_world(interp))[1] + isnothing(match) && return true + startswith( + string(match.method.name), "#(overlay (. Reactant (inert REACTANT_METHOD_TABLE))" + ) && return false + + # Avoid recursively interpreting into methods we define explicitly + # as overloads, which we assume should handle the entirety of the + # translation (and if not they can use call_in_reactant). + isdefined(match.method, :external_mt) && + match.method.external_mt === REACTANT_METHOD_TABLE && + return false + + match.method.file in __skip_files && return false + + #Dynamic dispatch handler + types = if match.method.nospecialize != 0 + match.method.sig + else + mi = CC.specialize_method(match) + mi.specTypes + end + + mask = lead_to_dynamic_call.(Base.unwrap_unionall(types).parameters) + #@error string(ft) mask types + return any(mask) +end + +function CC.abstract_eval_call( + interp::ReactantInterpreter, + e::Expr, + vtypes::Union{CC.VarTable,Nothing}, + sv::CC.AbsIntState, +) + if !(sv isa CC.IRInterpretationState) #during type inference, rewrite dynamic call with call_with_reactant + global disable_call_with_reactant + if !disable_call_with_reactant + argtypes = CC.collect_argtypes(interp, e.args, vtypes, sv) + args = CC.argtypes_to_type(argtypes).parameters + fn = CC.singleton_type(argtypes[1]) + if need_rewrite_call(interp, fn, args[2:end]) + @error fn string(argtypes) sv.linfo + log(interp) && push!( + interp.debug_data[].rewrite_call, + (fn, args[2:end], sv.linfo), #CC.copy(sv.src) + ) + e = Expr(:call, GlobalRef(@__MODULE__, :call_with_reactant), e.args...) + expr = sv.src.code[sv.currpc] + sv.src.code[sv.currpc] = if expr.head == :call + e + else + @assert expr.head == :(=) #CodeInfo slot write + Expr(:(=), expr.args[1], e) + end + end + else + log(interp) && push!( + interp.debug_data[].non_rewrite_call, + (sv.linfo, CC.collect_argtypes(interp, e.args, vtypes, sv)), + ) + end + end + + return @invoke CC.abstract_eval_call( + interp::CC.AbstractInterpreter, + e::Expr, + vtypes::Union{CC.VarTable,Nothing}, + sv::CC.AbsIntState, + ) +end + +using LLVM, LLVM.Interop + +struct CompilerInstance + lljit::LLVM.JuliaOJIT + lctm::LLVM.LazyCallThroughManager + ism::LLVM.IndirectStubsManager +end +const jit = Ref{CompilerInstance}() + +function get_trampoline(job) + (; lljit, lctm, ism) = jit[] + jd = JITDylib(lljit) + + target_sym = String(gensym(string(job.source))) + + # symbol flags (callable + exported) + flags = LLVM.API.LLVMJITSymbolFlags( + LLVM.API.LLVMJITSymbolGenericFlagsCallable | + LLVM.API.LLVMJITSymbolGenericFlagsExported, + 0, + ) + + sym = Ref(LLVM.API.LLVMOrcCSymbolFlagsMapPair(mangle(lljit, target_sym), flags)) + + # materialize callback: compile/emit module when symbols requested + function materialize(mr) + JuliaContext() do ctx + ir, meta = GPUCompiler.compile(:llvm, job; validate=false) + runtime_log(GPUCompiler.get_interpreter(job)) && @warn "materialize" job + @lk ir + # Ensure the module's entry has the target name we declared + LLVM.name!(meta.entry, target_sym) + r_symbols = string.(LLVM.get_requested_symbols(mr)) + #expose only the function defined in job + for f in LLVM.functions(ir) + isempty(LLVM.blocks(f)) && continue #declare functions + LLVM.name(f) in r_symbols && continue + LLVM.linkage!(f, LLVM.API.LLVMPrivateLinkage) + end + + #convert global alias to private linkage in order to not be relocatable + for g in LLVM.globals(ir) + ua = LLVM.API.LLVMGetUnnamedAddress(g) + (ua == LLVM.API.LLVMLocalUnnamedAddr || ua == LLVM.API.LLVMNoUnnamedAddr) || + continue + LLVM.isconstant(g) && continue + LLVM.API.LLVMSetUnnamedAddress(g, LLVM.API.LLVMNoUnnamedAddr) + LLVM.linkage!(g, LLVM.API.LLVMPrivateLinkage) + end + # serialize the module IR into a memory buffer + buf = convert(MemoryBuffer, ir) + # deserialize under a thread-safe context and emit via IRCompileLayer + ThreadSafeContext() do ts_ctx + tsm = context!(context(ts_ctx)) do + mod = parse(LLVM.Module, buf) + ThreadSafeModule(mod) + end + + il = LLVM.IRCompileLayer(lljit) + # Emit the ThreadSafeModule for the responsibility mr. + LLVM.emit(il, mr, tsm) + end + end + return nothing + end + + # discard callback (no-op for now) + function discard(jd_arg, sym) + @error "discard" sym + end + + # Create a single CustomMaterializationUnit that declares both entry and target. + # Name it something descriptive (e.g., the entry_sym) + mu = LLVM.CustomMaterializationUnit("MU_" * target_sym, sym, materialize, discard) + + # Define the MU in the JITDylib (declares the symbols as owned by this MU) + LLVM.define(jd, mu) + + # Lookup the entry address (this will trigger materialize if needed) + addr = lookup(lljit, target_sym) + return addr +end +import GPUCompiler: deferred_codegen_jobs + +function ccall_deferred(ptr::Ptr{Cvoid}) + return ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), ptr) +end + +""" + Reactant.REDUB_ARGUMENTS_NAME + +The variable name bound to `call_with_reactant`'s tuple of arguments in its +`@generated` method definition. + +This binding can be used to manually reference/destructure `call_with_reactants` arguments + +This is required because user arguments could have a name which clashes with whatever name we choose for +our argument. Thus we gensym to create it. + +This originates from + https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 + https://github.com/JuliaGPU/GPUCompiler.jl/blob/master/examples/jit.jl +""" +const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") + +function deferred_call_with_reactant( + world::UInt, source::LineNumberNode, self, @nospecialize(args) +) + f = args[1] + tt = Tuple{f,args[2:end]...} + match = CC._findsup(tt, REACTANT_METHOD_TABLE, world) + match = isnothing(match[1]) ? CC._findsup(tt, nothing, world) : match + + stub = Core.GeneratedFunctionStub( + identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() + ) + + if isnothing(match[1]) + method_error = :(throw( + MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) + )) + return stub(world, source, method_error) + end + + mi = CC.specialize_method(match[1]) + + target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=false) + config = CompilerConfig( + target, + CompilerParams(); + kernel=false, + libraries=false, + toplevel=true, + validate=false, + strip=false, + optimize=true, + entry_abi=:func, + ) + job = CompilerJob(mi, config, world) + interp = GPUCompiler.get_interpreter(job) + + ci = CC.typeinf_ext(interp, mi) + @assert !isnothing(ci) + rt = ci.rettype + @lk ci job + runtime_log(interp) && @warn "ci rt" job ci rt + + addr = get_trampoline(job) + trampoline = pointer(addr) + id = Base.reinterpret(Int, trampoline) + + deferred_codegen_jobs[id] = job + + #build CodeInfo directly + code_info = begin + ir = CC.IRCode() + src = @ccall jl_new_code_info_uninit()::Ref{CC.CodeInfo} + src.slotnames = fill(:none, length(ir.argtypes) + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = UInt64 + CC.ir_to_codeinf!(src, ir) + end + + overdubbed_code = Any[] + overdubbed_codelocs = Int32[] + function push_inst!(inst) + push!(overdubbed_code, inst) + push!(overdubbed_codelocs, code_info.codelocs[1]) + return Core.SSAValue(length(overdubbed_code)) + end + code_info.edges = Core.MethodInstance[job.source] + code_info.rettype = rt + + ptr = push_inst!(Expr(:call, :ccall_deferred, trampoline)) + + fn_args = [] + for i in 2:length(args) + named_tuple_ssa = Expr( + :call, Core.GlobalRef(Core, :getfield), Core.SlotNumber(2), i + ) + arg = push_inst!(named_tuple_ssa) + push!(fn_args, arg) + end + + f_arg = push_inst!(Expr(:call, Core.GlobalRef(Core, :getfield), Core.SlotNumber(2), 1)) + + args_vec = push_inst!( + Expr(:call, GlobalRef(Base, :getindex), GlobalRef(Base, :Any), fn_args...) + ) + + runtime_log(interp) && push_inst!( + Expr( + :call, + GlobalRef(Base, :println), + "before call_with_reactant ", + f_arg, + "(", + args_vec, + ")", + ), + ) + preserve = push_inst!(Expr(:gc_preserve_begin, args_vec)) + args_vec = push_inst!(Expr(:call, GlobalRef(Base, :pointer), args_vec)) + n_args = length(fn_args) + + #Use ccall internal directly to call the wrapped llvm function + result = push_inst!( + Expr( + :foreigncall, + ptr, + Ptr{rt}, + Core.svec(Any, Ptr{Any}, Int), + 0, + QuoteNode(:ccall), + f_arg, + args_vec, + n_args, + n_args, + args_vec, + f_arg, + ), + ) + + result = push_inst!(Expr(:call, GlobalRef(Base, :unsafe_pointer_to_objref), result)) + push_inst!(Expr(:gc_preserve_end, preserve)) + result = push_inst!(Expr(:call, GlobalRef(@__MODULE__, :barrier), result, rt)) + runtime_log(interp) && push_inst!( + Expr( + :call, + GlobalRef(Base, :println), + "after call_with_reactant ", + f_arg, + " ", + result, + ), + ) + push_inst!(Core.ReturnNode(result)) + + code_info.min_world = typemin(UInt) + code_info.max_world = typemax(UInt) + code_info.slotnames = Any[:call_with_reactant_, REDUB_ARGUMENTS_NAME] + code_info.slotflags = UInt8[0x00, 0x00] + code_info.code = overdubbed_code + code_info.codelocs = overdubbed_codelocs + code_info.ssavaluetypes = length(overdubbed_code) + code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] + return code_info +end + +@eval function call_with_reactant($(REDUB_ARGUMENTS_NAME)...) + $(Expr(:meta, :generated_only)) + return $(Expr(:meta, :generated, deferred_call_with_reactant)) +end +const jd_main = Ref{Any}() +function init_jit() + lljit = JuliaOJIT() + jd_main[] = JITDylib(lljit) + prefix = LLVM.get_prefix(lljit) + + dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix) + add!(jd_main[], dg) + + es = ExecutionSession(lljit) + + lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es) + ism = LLVM.LocalIndirectStubsManager(triple(lljit)) + + jit[] = CompilerInstance(lljit, lctm, ism) + atexit() do + (; lljit, lctm, ism) = jit[] + dispose(ism) + dispose(lctm) + dispose(lljit) + end +end + +function ir_to_codeinfo!(ir::CC.IRCode)::CC.CodeInfo + code_info = begin + src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) + src.slotnames = fill(:none, length(ir.argtypes) + 1) + src.slotflags = fill(zero(UInt8), length(ir.argtypes)) + src.slottypes = copy(ir.argtypes) + src.rettype = Int + CC.ir_to_codeinf!(src, ir) + src.ssavaluetypes = length(src.ssavaluetypes) + src + end + return code_info +end + +struct FakeOc + f::Vector + ci::Vector{CC.CodeInfo} +end + +fake_oc_dict = FakeOc([], []) + +fake_oc(ir::CC.IRCode, return_type=Any) = begin + src = ir_to_codeinfo!(ir) + fake_oc(src, return_type) +end + +function fake_oc(src::CC.CodeInfo, return_type=Any; args=nothing) + @assert !isnothing(current_interpreter[]) + types = isnothing(args) ? src.slottypes[2:end] : args + global fake_oc_dict + index = findfirst(==(src), fake_oc_dict.ci) + !isnothing(index) && return fake_oc_dict.f[index] + + expr = (Expr(:(::), Symbol("arg_$i"), type) for (i, type) in enumerate(types)) + args = Expr(:tuple, (Symbol("arg_$i") for (i, type) in enumerate(types))...) + fn_name = gensym(:fake_oc) + call_expr = Expr(:call, fn_name, expr...) + f_expr = Expr( + :(=), + call_expr, + quote + Reactant.barrier($args, $return_type) + end, + ) + f = @eval @noinline $f_expr + mi = Base.method_instance(f, types) + @assert !isnothing(mi) + mi.def.source = CC.maybe_compress_codeinfo(current_interpreter[], mi, src) + push!(fake_oc_dict.f, f) + push!(fake_oc_dict.ci, src) + return f +end diff --git a/src/Overlay.jl b/src/Overlay.jl index 929df7f51d..7aad933d35 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -62,7 +62,7 @@ for randfun in (:rand, :randn, :randexp) end @warn "Reactant doesn't support sampling of $(T) with the current \ interpreter. Falling back to native interpreter." maxlog = 1 - return Base.inferencebarrier(Random.$(randfun))(rng, T, dims) + return call_with_native(Random.$(randfun), rng, T, dims) end @reactant_overlay @noinline function Random.$(randfun)( @@ -81,7 +81,7 @@ for randfun in (:rand, :randn, :randexp) end @warn "Reactant doesn't support sampling of $(T) with the current \ interpreter. Falling back to native interpreter." maxlog = 1 - return Base.inferencebarrier(Random.$(randfun))(rng, T, dim1, dims...) + return call_with_native(Random.$(randfun), rng, T, dim1, dims...) end # scalars @@ -93,7 +93,7 @@ for randfun in (:rand, :randn, :randexp) end @warn "Reactant doesn't support sampling of $(T) with the current \ interpreter. Falling back to native interpreter." maxlog = 1 - return Base.inferencebarrier(Random.$(randfun))(rng, T) + return call_with_native(Random.$(randfun), rng, T) end # inplace @@ -130,7 +130,7 @@ for (cT, aT, bT) in ( # Inference barrier is required when calling function recursively within # overload. This is required since otherwise type inference will think this # is a recursive edge rather than a call to the base method - Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β) + call_with_native(LinearAlgebra.mul!, C, A, B, α, β) end return C end @@ -155,7 +155,7 @@ end # Inference barrier is required when calling function recursively within # overload. This is required since otherwise type inference will think this is # a recursive edge rather than a call to the base method - return Base.inferencebarrier(Base._stack)(dims, iter2) + return call_with_native(Base._stack, dims, iter2) end end end @@ -165,7 +165,7 @@ end if use_overlayed_version(A) error("Reactant doesn't have a `Base._unique_dims` with the current interpreter.") else - Base.inferencebarrier(Base._unique_dims)(A, dims) + call_with_native(Base._unique_dims, A, dims) end end @@ -178,8 +178,8 @@ end if use_overlayed_version(A) return TracedRArrayOverrides.overloaded_mapreduce(f, op, A; kwargs...) else - return Base.inferencebarrier(Base.mapreduce)( - CallWithReactant(f), CallWithReactant(op), A; kwargs... + return call_with_native( + Base.mapreduce, CallWithReactant(f), CallWithReactant(op), A; kwargs... ) end end @@ -188,7 +188,7 @@ end if use_overlayed_version(x) || looped_any(use_overlayed_version, ys) return TracedRArrayOverrides.overloaded_map(f, x, ys...) else - return Base.inferencebarrier(Base.map)(CallWithReactant(f), x, ys...) + return call_with_native(Base.map, CallWithReactant(f), x, ys...) end end @@ -202,7 +202,7 @@ end ) return TracedRArrayOverrides.overloaded_map!(f, y, x, xs...) else - return Base.inferencebarrier(Base.map!)(CallWithReactant(f), y, x, xs...) + return call_with_native(Base.map!, CallWithReactant(f), y, x, xs...) end end @@ -210,7 +210,7 @@ end if use_overlayed_version(x) return TracedRArrayOverrides.overloaded_mapreduce(f, &, x; dims) else - return Base.inferencebarrier(Base._all)(CallWithReactant(f), x, dims) + return call_with_native(Base._all, CallWithReactant(f), x, dims) end end @@ -218,7 +218,7 @@ end if use_overlayed_version(x) return TracedRArrayOverrides.overloaded_mapreduce(f, |, x; dims) else - return Base.inferencebarrier(Base._any)(CallWithReactant(f), x, dims) + return call_with_native(Base._any, CallWithReactant(f), x, dims) end end @@ -227,7 +227,7 @@ end if use_overlayed_version(x) return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) else - return Base.inferencebarrier(LinearAlgebra.lu)(x; kwargs...) + return call_with_native(LinearAlgebra.lu, x; kwargs...) end end @reactant_overlay @noinline function LinearAlgebra.lu( @@ -236,14 +236,14 @@ end if use_overlayed_version(x) return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) else - return Base.inferencebarrier(LinearAlgebra.lu)(x, pivot; kwargs...) + return call_with_native(LinearAlgebra.lu, x, pivot; kwargs...) end end @reactant_overlay @noinline function LinearAlgebra.lu!(x::AbstractArray; kwargs...) if use_overlayed_version(x) return TracedLinearAlgebra.overloaded_lu(x, RowMaximum(); kwargs...) else - return Base.inferencebarrier(LinearAlgebra.lu!)(x; kwargs...) + return call_with_native(LinearAlgebra.lu!, x; kwargs...) end end @reactant_overlay @noinline function LinearAlgebra.lu!( @@ -252,7 +252,7 @@ end if use_overlayed_version(x) return TracedLinearAlgebra.overloaded_lu(x, pivot; kwargs...) else - return Base.inferencebarrier(LinearAlgebra.lu!)(x, pivot; kwargs...) + return call_with_native(LinearAlgebra.lu!, x, pivot; kwargs...) end end @@ -260,7 +260,7 @@ end if use_overlayed_version(x) || use_overlayed_version(y) return TracedLinearAlgebra.overloaded_dot(x, y) else - return Base.inferencebarrier(LinearAlgebra.dot)(x, y) + return call_with_native(LinearAlgebra.dot, x, y) end end @reactant_overlay @noinline function LinearAlgebra.dot( @@ -269,6 +269,6 @@ end if use_overlayed_version(x) || use_overlayed_version(A) || use_overlayed_version(y) return TracedLinearAlgebra.overloaded_dot(x, A, y) else - return Base.inferencebarrier(LinearAlgebra.dot)(x, A, y) + return call_with_native(LinearAlgebra.dot, x, A, y) end end diff --git a/src/Precompile.jl b/src/Precompile.jl index c467bf7fd5..5ebd31f999 100644 --- a/src/Precompile.jl +++ b/src/Precompile.jl @@ -1,52 +1,5 @@ using PrecompileTools: @setup_workload, @compile_workload -function infer_sig(sig) - interp = ReactantInterpreter() - - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - - lookup_result = lookup_world( - sig, interp.world, Core.Compiler.method_table(interp), min_world, max_world - ) - match = lookup_result::Core.MethodMatch - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - - @static if VERSION < v"1.11" - # For older Julia versions, we vendor in some of the code to prevent - # having to build the MethodInstance twice. - result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) - frame = CC.InferenceState(result, :no, interp) - @assert !isnothing(frame) - CC.typeinf(interp, frame) - ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) - rt = CC.widenconst(CC.ignorelimited(result.result)) - else - ir, rt = CC.typeinf_ircode(interp, mi, nothing) - end -end - -function clear_oc_cache() - # Opaque closures capture the worldage of their compilation and thus are not relocatable - # Therefore we explicitly purge all OC's we have created here - for v in oc_capture_vec - if v isa Base.RefValue - p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v)) - Base.atomic_pointerset(p, C_NULL, :monotonic) - else - empty!(v) - end - end -end - # Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947 function precompilation_supported() return (VERSION >= v"1.11" || VERSION >= v"1.10.8") && (VERSION < v"1.12-") @@ -54,6 +7,7 @@ end if Reactant_jll.is_available() @setup_workload begin + init_jit() initialize_dialect() if XLA.REACTANT_XLA_RUNTIME == "PJRT" @@ -95,6 +49,5 @@ if Reactant_jll.is_available() XLA.free_client(client) client.client = C_NULL deinitialize_dialect() - clear_oc_cache() end end diff --git a/src/Reactant.jl b/src/Reactant.jl index 7c31f1a8c5..e145f688ab 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -91,7 +91,7 @@ include("xla/XLA.jl") include("Configuration.jl") include("Sharding.jl") include("Devices.jl") -include("Interpreter.jl") +include("JIT.jl") include("Profiler.jl") include("Types.jl") include("Distributed.jl") @@ -342,7 +342,7 @@ function __init__() """ maxlog = 1 end end - + init_jit() return nothing end diff --git a/src/utils.jl b/src/utils.jl index c7cb254946..d23719f2ff 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -20,881 +20,9 @@ end function call_with_reactant end -function maybe_argextype(@nospecialize(x), src) - return try - Core.Compiler.argextype(x, src) - catch err - !(err isa Core.Compiler.InvalidIRError) && rethrow() - nothing - end -end - # Defined in KernelAbstractions Ext function ka_with_reactant end -""" - Reactant.REDUB_ARGUMENTS_NAME - -The variable name bound to `call_with_reactant`'s tuple of arguments in its -`@generated` method definition. - -This binding can be used to manually reference/destructure `call_with_reactants` arguments - -This is required because user arguments could have a name which clashes with whatever name we choose for -our argument. Thus we gensym to create it. - -This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 -""" -const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") - -function throw_method_error(argtys) - throw(MethodError(argtys[1], argtys[2:end])) -end - -@inline function lookup_world( - @nospecialize(sig::Type), - world::UInt, - mt::Union{Nothing,Core.MethodTable}, - min_world::Ref{UInt}, - max_world::Ref{UInt}, -) - res = ccall( - :jl_gf_invoke_lookup_worlds, - Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, - mt, - world, - min_world, - max_world, - ) - return res -end - -@inline function lookup_world( - @nospecialize(sig::Type), - world::UInt, - mt::Core.Compiler.InternalMethodTable, - min_world::Ref{UInt}, - max_world::Ref{UInt}, -) - res = lookup_world(sig, mt.world, nothing, min_world, max_world) - return res -end - -@inline function lookup_world( - @nospecialize(sig::Type), - world::UInt, - mt::Core.Compiler.OverlayMethodTable, - min_world::Ref{UInt}, - max_world::Ref{UInt}, -) - res = lookup_world(sig, mt.world, mt.mt, min_world, max_world) - if res !== nothing - return res - else - return lookup_world(sig, mt.world, nothing, min_world, max_world) - end -end - -function has_ancestor(query::Module, target::Module) - query == target && return true - while true - next = parentmodule(query) - next == target && return true - next == query && return false - query = next - end -end - -const __skip_rewrite_func_set_lock = ReentrantLock() -const __skip_rewrite_func_set = Set([ - # Avoid the 1.10 stackoverflow - typeof(Base.typed_hvcat), - typeof(Base.hvcat), - typeof(Core.Compiler.concrete_eval_eligible), - typeof(Core.Compiler.typeinf_type), - typeof(Core.Compiler.typeinf_ext), - # TODO: perhaps problematic calls in `traced_call` - # should be moved to TracedUtils.jl: - typeof(ReactantCore.traced_call), - typeof(ReactantCore.is_traced), - # Perf optimization - typeof(Base.typemax), - typeof(Base.typemin), - typeof(Base.getproperty), - typeof(Base.vect), - typeof(Base.eltype), - typeof(Base.argtail), - typeof(Base.identity), - typeof(Base.print), - typeof(Base.println), - typeof(Base.show), - typeof(Base.show_delim_array), - typeof(Base.sprint), - typeof(Adapt.adapt_structure), - typeof(Core.is_top_bit_set), - typeof(Base.setindex_widen_up_to), - typeof(Base.typejoin), - typeof(Base.argtype_decl), - typeof(Base.arg_decl_parts), - typeof(Base.StackTraces.show_spec_sig), - typeof(Core.Compiler.return_type), - typeof(Core.throw_inexacterror), - typeof(Base.throw_boundserror), - typeof(Base._shrink), - typeof(Base._shrink!), - typeof(Base.ht_keyindex), - typeof(Base.checkindex), - typeof(Base.to_index), - @static( - if VERSION >= v"1.11.0" - typeof(Base.memoryref) - end - ), - typeof(materialize_traced_array), -]) - -""" - @skip_rewrite_func f - -Mark function `f` so that Reactant's IR rewrite mechanism will skip it. -This can improve compilation time if it's safe to assume that no call inside `f` -will need a `@reactant_overlay` method. - -!!! info - Note that this marks the whole function, not a specific method with a type - signature. - -!!! warning - The macro call should be inside the `__init__` function. If you want to - mark it for precompilation, you must add the macro call in the global scope - too. - -See also: [`@skip_rewrite_type`](@ref) -""" -macro skip_rewrite_func(fname) - quote - @lock $(Reactant.__skip_rewrite_func_set_lock) push!( - $(Reactant.__skip_rewrite_func_set), typeof($(esc(fname))) - ) - end -end - -const __skip_rewrite_type_constructor_list_lock = ReentrantLock() -const __skip_rewrite_type_constructor_list = [ - # Don't rewrite Val - Type{Base.Val}, - # Don't rewrite exception constructors - Type{<:Core.Exception}, - # Don't rewrite traced constructors - Type{<:TracedRArray}, - Type{<:TracedRNumber}, - Type{MLIR.IR.Location}, - Type{MLIR.IR.Block}, -] - -""" - @skip_rewrite_type MyStruct - @skip_rewrite_type Type{<:MyStruct} - -Mark the construct function of `MyStruct` so that Reactant's IR rewrite mechanism -will skip it. It does the same as [`@skip_rewrite_func`](@ref) but for type -constructors. - -If you want to mark the set of constructors over it's type parameters or over its -abstract type, you should use then the `Type{<:MyStruct}` syntax. - -!!! warning - The macro call should be inside the `__init__` function. If you want to - mark it for precompilation, you must add the macro call in the global scope - too. -""" -macro skip_rewrite_type(typ) - typ = if Base.isexpr(typ, :curly) && typ.args[1] === :Type - typ - else - Expr(:curly, :Type, typ) - end - return quote - @lock $(Reactant.__skip_rewrite_type_constructor_list_lock) push!( - $(Reactant.__skip_rewrite_type_constructor_list), $(esc(typ)) - ) - end -end - -function should_rewrite_call(@nospecialize(ft)) - # Don't rewrite builtin or intrinsics - if ft <: Core.IntrinsicFunction || ft <: Core.Builtin - return false - end - if ft <: Core.Function - if hasfield(typeof(ft), :name) && - hasfield(typeof(ft.name), :name) && - isdefined(ft.name, :name) - namestr = String(ft.name.name) - if startswith(namestr, "##(overlay (. Reactant (inert REACTANT_METHOD_TABLE)") - return false - end - end - - # We need this for closures to work - if hasfield(typeof(ft), :name) && hasfield(typeof(ft.name), :module) - mod = ft.name.module - # Don't rewrite primitive ops, tracing utilities, or any MLIR-based functions - if has_ancestor(mod, Ops) || - has_ancestor(mod, TracedUtils) || - has_ancestor(mod, MLIR) - return false - end - if string(mod) == "CUDA" - if ft.name.name == Symbol("#launch_configuration") - return false - end - if ft.name.name == Symbol("cudaconvert") - return false - end - end - end - end - - # `ft isa Type` is for performance as it avoids checking against all the list, but can be removed if problematic - if ft isa Type && any(t -> ft <: t, __skip_rewrite_type_constructor_list) - return false - end - - if ft in __skip_rewrite_func_set - return false - end - - # Default assume all functions need to be reactant-ified - return true -end - -# by default, same as `should_rewrite_call` -function should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) - # TODO how can we extend `@skip_rewrite` to methods? - if ft <: typeof(repeat) && (args == Tuple{String,Int64} || args == Tuple{Char,Int64}) - return false - end - return should_rewrite_call(ft) -end - -# Avoid recursively interpreting into methods we define explicitly -# as overloads, which we assume should handle the entirety of the -# translation (and if not they can use call_in_reactant). -function is_reactant_method(mi::Core.MethodInstance) - meth = mi.def - if !isdefined(meth, :external_mt) - return false - end - mt = meth.external_mt - return mt === REACTANT_METHOD_TABLE -end - -struct MustThrowError end - -@generated function applyiterate_with_reactant( - iteratefn, applyfn, args::Vararg{Any,N} -) where {N} - if iteratefn != typeof(Base.iterate) - return quote - error("Unhandled apply_iterate with iteratefn=$iteratefn") - end - end - newargs = Vector{Expr}(undef, N) - for i in 1:N - @inbounds newargs[i] = :(args[$i]...) - end - quote - Base.@_inline_meta - call_with_reactant(applyfn, $(newargs...)) - end -end - -@generated function applyiterate_with_reactant( - mt::MustThrowError, iteratefn, applyfn, args::Vararg{Any,N} -) where {N} - @assert iteratefn == typeof(Base.iterate) - newargs = Vector{Expr}(undef, N) - for i in 1:N - @inbounds newargs[i] = :(args[$i]...) - end - quote - Base.@_inline_meta - call_with_reactant(mt, applyfn, $(newargs...)) - end -end - -function certain_error() - throw( - AssertionError( - "The inferred code was guaranteed to throw this error. And yet, it didn't. So here we are...", - ), - ) -end - -function rewrite_inst(inst, ir, interp, RT, guaranteed_error) - if Meta.isexpr(inst, :call) - # Even if type unstable we do not want (or need) to replace intrinsic - # calls or builtins with our version. - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) - if ft == typeof(Core.kwcall) - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) - end - if ft == typeof(Core._apply_iterate) - ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) - if Base.invokelatest(should_rewrite_call, ft) - if RT === Union{} - rep = Expr( - :call, - applyiterate_with_reactant, - MustThrowError(), - inst.args[2:end]..., - ) - return true, rep, Union{} - else - rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...) - return true, rep, Any - end - end - elseif Base.invokelatest(should_rewrite_call, ft) - if RT === Union{} - rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...) - return true, rep, Union{} - else - rep = Expr(:call, call_with_reactant, inst.args...) - return true, rep, Any - end - end - end - if Meta.isexpr(inst, :invoke) - omi = inst.args[1]::Core.MethodInstance - sig = omi.specTypes - ft = sig.parameters[1] - argsig = sig.parameters[2:end] - if ft == typeof(Core.kwcall) - ft = sig.parameters[3] - argsig = sig.parameters[4:end] - end - argsig = Core.apply_type(Core.Tuple, argsig...) - if Base.invokelatest(should_rewrite_invoke, ft, argsig) && !is_reactant_method(omi) - method = omi.def::Core.Method - - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - - # RT = Any - - if !method.isva || !Base.isvarargtype(sig.parameters[end]) - if RT === Union{} - sig2 = Tuple{ - typeof(call_with_reactant),MustThrowError,sig.parameters... - } - else - sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} - end - else - vartup = inst.args[end] - ns = Type[] - eT = sig.parameters[end].T - for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1)) - push!(ns, eT) - end - if RT === Union{} - sig2 = Tuple{ - typeof(call_with_reactant), - MustThrowError, - sig.parameters[1:(end - 1)]..., - ns..., - } - else - sig2 = Tuple{ - typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... - } - end - end - - lookup_result = lookup_world( - sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world - ) - - match = lookup_result::Core.MethodMatch - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - n_method_args = method.nargs - if RT === Union{} - rep = Expr( - :invoke, mi, call_with_reactant, MustThrowError(), inst.args[2:end]... - ) - return true, rep, Union{} - else - rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) - return true, rep, Any - end - end - end - if isa(inst, Core.ReturnNode) && (!isdefined(inst, :val) || guaranteed_error) - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - - sig2 = Tuple{typeof(certain_error)} - - lookup_result = lookup_world( - sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world - ) - - match = lookup_result::Core.MethodMatch - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - rep = Expr(:invoke, mi, certain_error) - return true, rep, Union{} - end - return false, inst, RT -end - -const oc_capture_vec = Vector{Any}() - -# Caching is both good to reducing compile times and necessary to work around julia bugs -# in OpaqueClosure's: https://github.com/JuliaLang/julia/issues/56833 -function make_oc_dict( - @nospecialize(oc_captures::Dict{FT,Core.OpaqueClosure}), - @nospecialize(sig::Type), - @nospecialize(rt::Type), - @nospecialize(src::Core.CodeInfo), - nargs::Int, - isva::Bool, - @nospecialize(f::FT) -)::Core.OpaqueClosure where {FT} - key = f - if haskey(oc_captures, key) - oc = oc_captures[key] - oc - else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure - oc_captures[key] = ores - return ores - end -end - -function make_oc_ref( - oc_captures::Base.RefValue{Core.OpaqueClosure}, - @nospecialize(sig::Type), - @nospecialize(rt::Type), - @nospecialize(src::Core.CodeInfo), - nargs::Int, - isva::Bool, - @nospecialize(f) -)::Core.OpaqueClosure - if Base.isassigned(oc_captures) - return oc_captures[] - else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure - oc_captures[] = ores - return ores - end -end - -function safe_print(name, x) - return ccall(:jl_, Cvoid, (Any,), name * " " * string(x)) -end - -const DEBUG_INTERP = Ref(false) - -# Rewrite type unstable calls to recurse into call_with_reactant to ensure -# they continue to use our interpreter. Reset the derived return type -# to Any if our interpreter would change the return type of any result. -# Also rewrite invoke (type stable call) to be :call, since otherwise apparently -# screws up type inference after this (TODO this should be fixed). -function rewrite_insts!(ir, interp, guaranteed_error) - any_changed = false - for (i, inst) in enumerate(ir.stmts) - # Explicitly skip any code which returns Union{} so that we throw the error - # instead of risking a segfault - RT = inst[:type] - @static if VERSION < v"1.11" - changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error) - Core.Compiler.setindex!(ir.stmts[i], next, :inst) - else - changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error) - Core.Compiler.setindex!(ir.stmts[i], next, :stmt) - end - if changed - any_changed = true - Core.Compiler.setindex!(ir.stmts[i], RT, :type) - end - end - return ir, any_changed -end - -function rewrite_argnumbers_by_one!(ir) - # Add one dummy argument at the beginning - pushfirst!(ir.argtypes, Nothing) - - # Re-write all references to existing arguments to their new index (N + 1) - for idx in 1:length(ir.stmts) - urs = Core.Compiler.userefs(ir.stmts[idx][:inst]) - changed = false - it = Core.Compiler.iterate(urs) - while it !== nothing - (ur, next) = it - old = Core.Compiler.getindex(ur) - if old isa Core.Argument - # Replace the Argument(n) with Argument(n + 1) - Core.Compiler.setindex!(ur, Core.Argument(old.n + 1)) - changed = true - end - it = Core.Compiler.iterate(urs, next) - end - if changed - @static if VERSION < v"1.11" - Core.Compiler.setindex!(ir.stmts[idx], Core.Compiler.getindex(urs), :inst) - else - Core.Compiler.setindex!(ir.stmts[idx], Core.Compiler.getindex(urs), :stmt) - end - end - end - - return nothing -end - -# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter -# In particular this entails two pieces: -# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance -# 2) Post type inference (using of course the reactant interpreter), all type unstable call functions are -# replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia -# using a custom interpreter in type unstable code. -# `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)` -function call_with_reactant_generator( - world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments) -) - @nospecialize - args = redub_arguments - if DEBUG_INTERP[] - safe_print("args", args) - end - - stub = Core.GeneratedFunctionStub( - identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() - ) - - fn = args[1] - sig = Tuple{args...} - - guaranteed_error = false - if fn === MustThrowError - guaranteed_error = true - fn = args[2] - sig = Tuple{args[2:end]...} - end - - # look up the method match - builtin_error = - :(throw(AssertionError("Unsupported call_with_reactant of builtin $fn"))) - - if fn <: Core.Builtin - return stub(world, source, builtin_error) - end - - if guaranteed_error - method_error = :(throw( - MethodError($REDUB_ARGUMENTS_NAME[2], $REDUB_ARGUMENTS_NAME[3:end], $world) - )) - else - method_error = :(throw( - MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) - )) - end - - interp = ReactantInterpreter(; world) - - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - - lookup_result = lookup_world( - sig, world, Core.Compiler.method_table(interp), min_world, max_world - ) - - overdubbed_code = Any[] - overdubbed_codelocs = Int32[] - - # No method could be found (including in our method table), bail with an error - if lookup_result === nothing - return stub(world, source, method_error) - end - - match = lookup_result::Core.MethodMatch - # look up the method and code instance - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - match.method, - match.spec_types, - match.sparams, - ) - method = mi.def - - @static if VERSION < v"1.11" - # For older Julia versions, we vendor in some of the code to prevent - # having to build the MethodInstance twice. - result = CC.InferenceResult(mi, CC.typeinf_lattice(interp)) - frame = CC.InferenceState(result, :no, interp) - @assert !isnothing(frame) - CC.typeinf(interp, frame) - ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing) - rt = CC.widenconst(CC.ignorelimited(result.result)) - else - ir, rt = CC.typeinf_ircode(interp, mi, nothing) - end - - if guaranteed_error - if rt !== Union{} - safe_print("Inconsistent guaranteed error IR", ir) - end - rt = Union{} - end - - if DEBUG_INTERP[] - safe_print("ir", ir) - end - - mi = mi::Core.MethodInstance - - if !( - is_reactant_method(mi) || ( - mi.def.sig isa DataType && - !should_rewrite_invoke( - mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...} - ) - ) - ) || guaranteed_error - ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error) - end - - rewrite_argnumbers_by_one!(ir) - - src = ccall(:jl_new_code_info_uninit, Ref{Core.CodeInfo}, ()) - src.slotnames = fill(:none, length(ir.argtypes) + 1) - src.slotflags = fill(zero(UInt8), length(ir.argtypes)) - src.slottypes = copy(ir.argtypes) - src.rettype = rt - src = CC.ir_to_codeinf!(src, ir) - - if DEBUG_INTERP[] - safe_print("src", src) - end - - # prepare a new code info - code_info = copy(src) - static_params = match.sparams - signature = sig - - # propagate edge metadata, this method is invalidated if the original function we are calling - # is invalidated - code_info.edges = Core.MethodInstance[mi] - code_info.min_world = min_world[] - code_info.max_world = max_world[] - - # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, - # and the REDUB_ARGUMENTS_NAME tuple of input arguments - code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] - code_info.slotflags = UInt8[0x00, 0x00] - n_prepended_slots = 2 - overdub_args_slot = Core.SlotNumber(n_prepended_slots) - - # For the sake of convenience, the rest of this pass will translate `code_info`'s fields - # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at - # the end of the pass, we'll reset `code_info` fields accordingly. - overdubbed_code = Any[] - overdubbed_codelocs = Int32[] - function push_inst!(inst) - push!(overdubbed_code, inst) - push!(overdubbed_codelocs, code_info.codelocs[1]) - return Core.SSAValue(length(overdubbed_code)) - end - # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention - # required by the base method. - - # destructure the generated argument slots into the overdubbed method's argument slots. - - offset = 1 - fn_args = Any[] - n_method_args = method.nargs - n_actual_args = length(redub_arguments) - if guaranteed_error - offset += 1 - n_actual_args -= 1 - end - - tys = [] - - iter_args = n_actual_args - if method.isva - iter_args = min(n_actual_args, n_method_args - 1) - end - - for i in 1:iter_args - actual_argument = Expr( - :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - ) - arg = push_inst!(actual_argument) - offset += 1 - push!(fn_args, arg) - push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)]) - - if DEBUG_INTERP[] - push_inst!( - Expr( - :call, - safe_print, - "fn arg[" * string(length(fn_args)) * "]", - fn_args[end], - ), - ) - end - end - - # If `method` is a varargs method, we have to restructure the original method call's - # trailing arguments into a tuple and assign that tuple to the expected argument slot. - if method.isva - trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple)) - for i in n_method_args:n_actual_args - arg = push_inst!( - Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset) - ) - push!(trailing_arguments.args, arg) - offset += 1 - end - - push!(fn_args, push_inst!(trailing_arguments)) - push!( - tys, - Tuple{ - redub_arguments[(n_method_args:n_actual_args) .+ (guaranteed_error ? 1 : 0)]..., - }, - ) - - if DEBUG_INTERP[] - push_inst!( - Expr( - :call, - safe_print, - "fn arg[" * string(length(fn_args)) * "]", - fn_args[end], - ), - ) - end - end - - # ocva = method.isva - - ocva = false # method.isva - - ocnargs = Int(method.nargs) - # octup = Tuple{mi.specTypes.parameters[2:end]...} - # octup = Tuple{method.sig.parameters[2:end]...} - octup = Tuple{tys[1:end]...} - ocva = false - - # jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right - # inner code during compilation without special handling (i.e. call_in_world_total). - # Opaque closures also require taking the function argument. We can work around the latter - # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure - - dict, make_oc = (Base.Ref{Core.OpaqueClosure}(), make_oc_ref) - - push!(oc_capture_vec, dict) - - oc = if false && Base.issingletontype(fn) - res = Core._call_in_world_total( - world, make_oc, dict, octup, rt, src, ocnargs, ocva, fn.instance - )::Core.OpaqueClosure - else - farg = fn_args[1] - farg = nothing - rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg) - push_inst!(rep) - Core.SSAValue(length(overdubbed_code)) - end - - push_inst!(Expr(:call, oc, fn_args[1:end]...)) - - ocres = Core.SSAValue(length(overdubbed_code)) - - if DEBUG_INTERP[] - push_inst!(Expr(:call, safe_print, "ocres", ocres)) - end - - push_inst!(Core.ReturnNode(ocres)) - - #=== set `code_info`/`reflection` fields accordingly ===# - - if code_info.method_for_inference_limit_heuristics === nothing - code_info.method_for_inference_limit_heuristics = method - end - - code_info.code = overdubbed_code - code_info.codelocs = overdubbed_codelocs - code_info.ssavaluetypes = length(overdubbed_code) - code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - - if DEBUG_INTERP[] - safe_print("code_info", code_info) - end - - return code_info -end - -@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...) - $(Expr(:meta, :generated_only)) - return $(Expr(:meta, :generated, call_with_reactant_generator)) -end - @static if isdefined(Core, :BFloat16) nmantissa(::Type{Core.BFloat16}) = 7 end From add3d4f8805c4141be125e91f340321f301c1145 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Oct 2025 11:14:22 -0400 Subject: [PATCH 2/3] Apply suggestion from @avik-pal --- ext/ReactantCUDAExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 2fa3bb6f3e..ebea33babe 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1465,7 +1465,7 @@ end @static if !Sys.isapple() @setup_workload begin Reactant.initialize_dialect() - init_jit() + Reactant.init_jit() if Reactant.XLA.REACTANT_XLA_RUNTIME == "PJRT" client = Reactant.XLA.PJRT.CPUClient(; checkcount=false) elseif Reactant.XLA.REACTANT_XLA_RUNTIME == "IFRT" From 7e0e56598a990651c4baa0dfc6d15bc95e17baea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABtan=20LOUNES?= Date: Wed, 29 Oct 2025 16:20:55 +0100 Subject: [PATCH 3/3] blacklist objectid --- src/JIT.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/JIT.jl b/src/JIT.jl index d86d54c727..ffe2ba1d0e 100644 --- a/src/JIT.jl +++ b/src/JIT.jl @@ -31,6 +31,7 @@ const __skip_rewrite_func_set = Set([ typeof(task_local_storage), typeof(getproperty), typeof(invokelatest), + typeof(objectid) ]) const __skip_rewrite_func_set_lock = ReentrantLock()