Skip to content

Commit 151ff21

Browse files
committed
fix: correct type for ad annotation
1 parent d8a27f4 commit 151ff21

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

src/Enzyme.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,16 @@ function act_attr(val)
295295
return MLIR.IR.Attribute(val)
296296
end
297297

298+
function overload_autodiff(
299+
mode::CMode, f::FA, args::Vararg{Annotation,Nargs}
300+
) where {CMode<:Mode,FA<:Annotation,Nargs}
301+
# need to guess the correct activity here. Execute the function, we will DCE it
302+
res = call_with_reactant(f.val, [x.val for x in args]...)
303+
return overload_autodiff(
304+
mode, f, Enzyme.guess_activity(Core.Typeof(res), mode), args...
305+
)
306+
end
307+
298308
function overload_autodiff(
299309
::CMode, f::FA, ::Type{A}, args::Vararg{Annotation,Nargs}
300310
) where {CMode<:Mode,FA<:Annotation,A<:Annotation,Nargs}

src/Overlay.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@
99
end
1010

1111
# Enzyme.jl overlays
12+
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
13+
rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs}
14+
) where {FA<:Annotation,Nargs}
15+
return overload_autodiff(rmode, f, args...)
16+
end
17+
18+
@reactant_overlay @noinline function Enzyme.autodiff(
19+
rmode::Enzyme.Mode, f::FA, args::Vararg{Annotation,Nargs}
20+
) where {FA<:Annotation,Nargs}
21+
return overload_autodiff(rmode, f, args...)
22+
end
23+
1224
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
1325
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}
1426
) where {FA<:Annotation,A<:Annotation,Nargs}

test/autodiff.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ using Enzyme, Reactant, Test, Random
22

33
square(x) = x * 2
44

5+
sum_without_activity(x) = sum(abs2, x; dims=(2, 3))
6+
57
fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
68

79
@testset "Activity" begin
@@ -39,6 +41,16 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
3941
Enzyme.Duplicated
4042
end
4143

44+
@testset "Correct Activity Guess" begin
45+
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 4, 5, 6))
46+
bx = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 3, 4, 5, 6))
47+
48+
res = only(@jit(Enzyme.autodiff(Forward, sum_without_activity, Duplicated(x, bx))))
49+
@test res isa Reactant.ConcreteRArray{Float32,4}
50+
@test size(res) == (3, 1, 1, 6)
51+
@test 2 .* sum_without_activity(bx) res
52+
end
53+
4254
@testset "Basic Forward Mode" begin
4355
res1 = @jit(
4456
fwd(

0 commit comments

Comments
 (0)