Skip to content

Commit 6fc66b3

Browse files
authored
fix: loading reactant on 1.12 (#1755)
* fix: loading reactant on 1.12 * chore: add a warning msg for 1.12 * fix: flip recommendation order * fix: missing is_closure * test: fix
1 parent 4b88602 commit 6fc66b3

File tree

3 files changed

+38
-16
lines changed

3 files changed

+38
-16
lines changed

src/Compiler.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3714,6 +3714,18 @@ struct Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD,DAM}
37143714
donated_args_mask::DAM
37153715
end
37163716

3717+
for fn in (:get_tag, :get_isclosure, :get_compiled_argtypes)
3718+
@eval $fn(thunk::Thunk) = $fn(typeof(thunk))
3719+
end
3720+
3721+
function get_compiled_argtypes(::Type{<:Thunk{<:Any,<:Any,<:Any,ArgTypes}}) where {ArgTypes}
3722+
return ArgTypes
3723+
end
3724+
3725+
get_tag(::Type{<:Thunk{<:Any,tag}}) where {tag} = tag
3726+
3727+
get_isclosure(::Type{<:Thunk{<:Any,<:Any,IsClosure}}) where {IsClosure} = IsClosure
3728+
37173729
function Base.show(io::IO, thunk::Thunk{<:Any,tag}) where {tag}
37183730
return print(io, "Reactant compiled function $(thunk.f) (with tag $(tag))")
37193731
end
@@ -3752,24 +3764,13 @@ function Base.showerror(
37523764
)
37533765
end
37543766

3755-
@generated function (
3756-
thunk::Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD,DAM}
3757-
)(
3758-
args...
3759-
) where {FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD,DAM}
3767+
@generated function (thunk::Thunk)(args...)
37603768
FoundTypes = Tuple{args...}
3761-
if ArgTypes != FoundTypes
3762-
return quote
3763-
throw(
3764-
$(MisMatchedThunkTypeError{
3765-
Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD,DAM},
3766-
FoundTypes,
3767-
}()),
3768-
)
3769-
end
3769+
if get_compiled_argtypes(thunk) != FoundTypes
3770+
return :(throw($(MisMatchedThunkTypeError{thunk,FoundTypes}())))
37703771
end
3771-
body = __thunk_fwd_body_cache[tag]
3772-
if IsClosure
3772+
body = __thunk_fwd_body_cache[get_tag(thunk)]
3773+
if get_isclosure(thunk)
37733774
return quote
37743775
args = (thunk.f, args...)
37753776
$body

src/Reactant.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,19 @@ function __init__()
330330
end
331331
end
332332

333+
@static if VERSION v"1.12-"
334+
if ccall(:jl_generating_output, Cint, ()) == 1
335+
@warn """
336+
Reactant.jl currently doesn't support versions of Julia 1.12 or newer. We are
337+
actively working on adding support for newer versions of Julia. For the time
338+
being we recommend using 1.11 or LTS.
339+
340+
For latest updates, check the status of support for Julia 1.12+ at
341+
https://github.com/EnzymeAD/Reactant.jl/issues/1736.
342+
""" maxlog = 1
343+
end
344+
end
345+
333346
return nothing
334347
end
335348

test/basic.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1671,3 +1671,11 @@ init_fill_0d(x) = fill(eltype(x)(5))
16711671
@test @jit(init_fill(x_ra)) init_fill(x)
16721672
@test @jit(init_fill_0d(x_ra)) init_fill_0d(x)
16731673
end
1674+
1675+
@testset "Mismatched Thunk Error" begin
1676+
x_ra1 = Reactant.to_rarray(rand(Float32, 2))
1677+
x_ra2 = Reactant.to_rarray(rand(Float64, 2))
1678+
1679+
fn = @compile sum(x_ra1)
1680+
@test_throws Reactant.Compiler.MisMatchedThunkTypeError fn(x_ra2)
1681+
end

0 commit comments

Comments
 (0)