Skip to content

Commit 6e1a02c

Browse files
avik-palwsmoses
andauthored
feat: some more 1.12 support
* test: version 1.12 Comment out older Go versions in CI matrix * test: v1.12 * ci: disable downgrade * fix: get 1.12 KAExt loading to work * Mildly functioning on 1.12 * ci * fix 1.10 * Change Julia version from 1.12 to 1.10 * chore: run formatter --------- Co-authored-by: William S. Moses <gh@wsmoses.com> Co-authored-by: William Moses <wmoses@google.com>
1 parent f498044 commit 6e1a02c

File tree

7 files changed

+115
-38
lines changed

7 files changed

+115
-38
lines changed

benchmark/aggregate.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ for backend in BACKENDS
1515
end
1616

1717
open(joinpath(dirname(@__FILE__), "results", "combinedbenchmarks.json"), "w") do io
18-
JSON3.pretty(io, JSON3.write(all_results))
18+
return JSON3.pretty(io, JSON3.write(all_results))
1919
end

benchmark/runbenchmarks.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ for (i, (k, v)) in enumerate(results)
4444
end
4545

4646
open(joinpath(filepath, filename), "w") do io
47-
JSON3.pretty(io, JSON3.write(standardized_results))
47+
return JSON3.pretty(io, JSON3.write(standardized_results))
4848
end
4949

5050
@info "Saved results to $(joinpath(filepath, filename))"

deps/build_local.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ run(Cmd(Cmd(build_cmd_list); dir=source_dir))
252252

253253
# Discover built libraries
254254
built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file
255-
endswith(file, "Extra.so") && startswith(file, "lib")
255+
return endswith(file, "Extra.so") && startswith(file, "lib")
256256
end
257257

258258
lib_path = joinpath(source_dir, "bazel-bin", only(built_libs))

ext/ReactantKernelAbstractionsExt.jl

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,14 +109,26 @@ function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsi
109109
return nothing
110110
end
111111

112-
Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function (
113-
obj::KA.Kernel{ReactantBackend}
114-
)(
115-
@nospecialize args...; ndrange=nothing, workgroupsize=nothing
116-
)
117-
return Reactant.call_with_reactant(
118-
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
112+
@static if VERSION < v"1.12-"
113+
Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function (
114+
obj::KA.Kernel{ReactantBackend}
115+
)(
116+
@nospecialize args...; ndrange=nothing, workgroupsize=nothing
119117
)
118+
return Reactant.call_with_reactant(
119+
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
120+
)
121+
end
122+
else
123+
Reactant.@reactant_overlay function (obj::KA.Kernel{ReactantBackend})(
124+
args...; ndrange=nothing, workgroupsize=nothing
125+
)
126+
Base.@_noinline_meta
127+
Base.@_nospecializeinfer_meta
128+
return Reactant.call_with_reactant(
129+
Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args...
130+
)
131+
end
120132
end
121133

122134
end

src/Interpreter.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,31 @@ function set_reactant_abi(
4141
if length(argtypes) != 1
4242
@static if VERSION < v"1.11.0-"
4343
return CallMeta(Union{}, Effects(), NoCallInfo())
44-
else
44+
elseif VERSION < v"1.12.0-"
4545
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
46+
else
47+
return Core.Compiler.Future{Core.Compiler.CallMeta}(
48+
CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
49+
)
4650
end
4751
end
4852
@static if VERSION < v"1.11.0-"
4953
return CallMeta(
5054
Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
5155
)
52-
else
56+
elseif VERSION < v"1.12.0-"
5357
return CallMeta(
5458
Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure()
5559
)
60+
else
61+
return Core.Compiler.Future{Core.Compiler.CallMeta}(
62+
CallMeta(
63+
Core.Const(true),
64+
Union{},
65+
Core.Compiler.EFFECTS_TOTAL,
66+
MethodResultPure(),
67+
),
68+
)
5669
end
5770
end
5871

src/utils.jl

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,11 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
369369
end
370370
end
371371
if Meta.isexpr(inst, :invoke)
372-
omi = inst.args[1]::Core.MethodInstance
372+
omi = if inst.args[1] isa Core.MethodInstance
373+
inst.args[1]
374+
else
375+
(inst.args[1]::Core.CodeInstance).def
376+
end
373377
sig = omi.specTypes
374378
ft = sig.parameters[1]
375379
argsig = sig.parameters[2:end]
@@ -518,22 +522,42 @@ function make_oc_ref(
518522
if Base.isassigned(oc_captures)
519523
return oc_captures[]
520524
else
521-
ores = ccall(
522-
:jl_new_opaque_closure_from_code_info,
523-
Any,
524-
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint),
525-
sig,
526-
rt,
527-
rt,
528-
@__MODULE__,
529-
src,
530-
0,
531-
nothing,
532-
nargs,
533-
isva,
534-
f,
535-
true,
536-
)::Core.OpaqueClosure
525+
ores = @static if VERSION < v"1.11"
526+
ccall(
527+
:jl_new_opaque_closure_from_code_info,
528+
Any,
529+
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint),
530+
sig,
531+
rt,
532+
rt,
533+
@__MODULE__,
534+
src,
535+
0,
536+
nothing,
537+
nargs,
538+
isva,
539+
f,
540+
true,
541+
)::Core.OpaqueClosure
542+
else
543+
ccall(
544+
:jl_new_opaque_closure_from_code_info,
545+
Any,
546+
(Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint),
547+
sig, # jl_tupletype_t *argt
548+
rt, # jl_value_t *rt_lb
549+
rt, # jl_value_t *rt_ub
550+
@__MODULE__, # jl_module_t *mod
551+
src, # jl_code_info_t *ci
552+
0, # int lineno
553+
nothing, # jl_value_t *file
554+
nargs, # int nargs
555+
isva, # int isva
556+
f, # jl_value_t *env
557+
true, # int do_compile
558+
true, # int isinferred
559+
)::Core.OpaqueClosure
560+
end
537561
oc_captures[] = ores
538562
return ores
539563
end
@@ -725,7 +749,9 @@ function call_with_reactant_generator(
725749
src.slotnames = fill(:none, length(ir.argtypes) + 1)
726750
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
727751
src.slottypes = copy(ir.argtypes)
728-
src.rettype = rt
752+
@static if VERSION < v"1.12.0-"
753+
src.rettype = rt
754+
end
729755
src = CC.ir_to_codeinf!(src, ir)
730756

731757
if DEBUG_INTERP[]
@@ -747,17 +773,31 @@ function call_with_reactant_generator(
747773
# and the REDUB_ARGUMENTS_NAME tuple of input arguments
748774
code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME]
749775
code_info.slotflags = UInt8[0x00, 0x00]
776+
777+
if VERSION >= v"1.12-"
778+
code_info.nargs = length(code_info.slotnames)
779+
code_info.isva = true
780+
end
781+
750782
n_prepended_slots = 2
751783
overdub_args_slot = Core.SlotNumber(n_prepended_slots)
752784

753785
# For the sake of convenience, the rest of this pass will translate `code_info`'s fields
754786
# into these overdubbed equivalents instead of updating `code_info` in-place. Then, at
755787
# the end of the pass, we'll reset `code_info` fields accordingly.
756788
overdubbed_code = Any[]
757-
overdubbed_codelocs = Int32[]
789+
790+
overdubbed_codelocs = @static if isdefined(Core, :DebugInfo)
791+
nothing
792+
else
793+
Int32[]
794+
end
795+
758796
function push_inst!(inst)
759797
push!(overdubbed_code, inst)
760-
push!(overdubbed_codelocs, code_info.codelocs[1])
798+
@static if !isdefined(Core, :DebugInfo)
799+
push!(overdubbed_codelocs, code_info.codelocs[1])
800+
end
761801
return Core.SSAValue(length(overdubbed_code))
762802
end
763803
# Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention
@@ -781,6 +821,11 @@ function call_with_reactant_generator(
781821
iter_args = min(n_actual_args, n_method_args - 1)
782822
end
783823

824+
if VERSION >= v"1.12-"
825+
src.nargs = length(src.slottypes)
826+
src.isva = false
827+
end
828+
784829
for i in 1:iter_args
785830
actual_argument = Expr(
786831
:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset
@@ -862,12 +907,9 @@ function call_with_reactant_generator(
862907
farg = nothing
863908
rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg)
864909
push_inst!(rep)
865-
Core.SSAValue(length(overdubbed_code))
866910
end
867911

868-
push_inst!(Expr(:call, oc, fn_args[1:end]...))
869-
870-
ocres = Core.SSAValue(length(overdubbed_code))
912+
ocres = push_inst!(Expr(:call, oc, fn_args[1:end]...))
871913

872914
if DEBUG_INTERP[]
873915
push_inst!(Expr(:call, safe_print, "ocres", ocres))
@@ -882,7 +924,13 @@ function call_with_reactant_generator(
882924
end
883925

884926
code_info.code = overdubbed_code
885-
code_info.codelocs = overdubbed_codelocs
927+
928+
@static if isdefined(Core, :DebugInfo)
929+
code_info.debuginfo = Core.DebugInfo(:none) # Core.DebugInfoStream(overdubbed_codelocs), length(overdubbed_codelocs))
930+
else
931+
code_info.codelocs = overdubbed_codelocs
932+
end
933+
886934
code_info.ssavaluetypes = length(overdubbed_code)
887935
code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code
888936

src/xla/PJRT/LoadedExecutable.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ function XLA.compile(
105105
end
106106

107107
function execute_ir(N, M, n_outs, with_device::Bool, nmesh_ids::Int64)
108-
ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32"
108+
ptr = @static if VERSION < v"1.12"
109+
sizeof(Int) == sizeof(Int64) ? "i64" : "i32"
110+
else
111+
"ptr"
112+
end
109113
cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32"
110114
args = N > 0 ? ", [$N x $ptr] %inps, [$M x i8] %donated" : ""
111115
if with_device

0 commit comments

Comments
 (0)