@@ -40,10 +40,6 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
4040end
4141
4242@testset " Basic Forward Mode" begin
43- ores1 = fwd (Forward, Duplicated, ones (3 , 2 ), 3.1 * ones (3 , 2 ))
44- @test typeof (ores1) == NamedTuple{(Symbol (" 1" ),),Tuple{Array{Float64,2 }}}
45- @test ores1[1 ] ≈ 6.2 * ones (3 , 2 )
46-
4743 res1 = @jit (
4844 fwd (
4945 Forward,
5450 )
5551
5652 @test res1 isa Tuple{<: ConcreteRArray{Float64,2} }
57- @test res1[1 ] ≈ ores1[1 ]
58-
59- ores1 = fwd (ForwardWithPrimal, Duplicated, ones (3 , 2 ), 3.1 * ones (3 , 2 ))
60- @test typeof (ores1) ==
61- NamedTuple{(Symbol (" 1" ), Symbol (" 2" )),Tuple{Array{Float64,2 },Array{Float64,2 }}}
53+ @test res1[1 ] ≈ fill (6.2 , 3 , 2 )
6254
6355 res1 = @jit (
6456 fwd (
7062 )
7163
7264 @test res1 isa Tuple{<: ConcreteRArray{Float64,2} ,<: ConcreteRArray{Float64,2} }
73- @test res1[1 ] ≈ ores1[1 ]
74- @test res1[2 ] ≈ ores1[2 ]
75-
76- ores1 = fwd (Forward, Const, ones (3 , 2 ), 3.1 * ones (3 , 2 ))
77- @test typeof (ores1) == Tuple{}
65+ @test res1[1 ] ≈ fill (6.2 , 3 , 2 )
66+ @test res1[2 ] ≈ fill (2.0 , 3 , 2 )
7867
7968 res1 = @jit (
8069 fwd (
8776
8877 @test typeof (res1) == Tuple{}
8978
90- ores1 = fwd (ForwardWithPrimal, Const, ones (3 , 2 ), 3.1 * ones (3 , 2 ))
91- @test typeof (ores1) == NamedTuple{(Symbol (" 1" ),),Tuple{Array{Float64,2 }}}
92-
9379 res1 = @jit (
9480 fwd (
9581 set_abi (ForwardWithPrimal, Reactant. ReactantABI),
10086 )
10187
10288 @test res1 isa Tuple{<: ConcreteRArray{Float64,2} }
103- @test res1[1 ] ≈ ores1[ 1 ]
89+ @test res1[1 ] ≈ fill ( 2.0 , 3 , 2 )
10490end
10591
10692function gw (z)
@@ -140,7 +126,7 @@ function cached_return(x, stret::StateReturn1)
140126end
141127
142128@testset " Cached Return: Issue #416" begin
143- x = rand ( 10 )
129+ x = Reactant . TestUtils . construct_test_array (Float64, 10 )
144130 x_ra = Reactant. to_rarray (x)
145131
146132 stret = StateReturn (nothing )
187173end
188174
189175@testset " onehot" begin
190- x = Reactant. to_rarray (rand (3 , 4 ))
176+ x = Reactant. to_rarray (ones (3 , 4 ))
191177 hlo = @code_hlo optimize = false Enzyme. onehot (x)
192178 @test ! contains (" stablehlo.constant" , repr (hlo))
193179end
@@ -202,27 +188,26 @@ end
202188 x = reshape (collect (Float32, 1 : 6 ), 3 , 2 )
203189 x_ra = Reactant. to_rarray (x)
204190 res = @jit vector_forward_ad (x_ra)
205- res_enz = vector_forward_ad (x)
206191
207192 @test x_ra ≈ x # See https://github.com/EnzymeAD/Reactant.jl/issues/1733
208- @test res[1 ][1 ] ≈ res_enz[ 1 ][ 1 ]
209- @test res[1 ][2 ] ≈ res_enz[ 1 ][ 2 ]
210- @test res[1 ][3 ] ≈ res_enz[ 1 ][ 3 ]
211- @test res[1 ][4 ] ≈ res_enz[ 1 ][ 4 ]
212- @test res[1 ][5 ] ≈ res_enz[ 1 ][ 5 ]
213- @test res[1 ][6 ] ≈ res_enz[ 1 ][ 6 ]
193+ @test res[1 ][1 ] ≈ 2
194+ @test res[1 ][2 ] ≈ 4
195+ @test res[1 ][3 ] ≈ 6
196+ @test res[1 ][4 ] ≈ 8
197+ @test res[1 ][5 ] ≈ 10
198+ @test res[1 ][6 ] ≈ 12
214199
215200 oh = Enzyme. onehot (x)
216201 oh_stacked = stack (oh)
217202 oh_ra = Reactant. to_rarray (oh_stacked)
218203 res2 = @jit vector_forward_ad2 (x_ra, oh_ra)
219204
220- @test res2[1 ][1 ] ≈ res_enz[ 1 ][ 1 ]
221- @test res2[1 ][2 ] ≈ res_enz[ 1 ][ 2 ]
222- @test res2[1 ][3 ] ≈ res_enz[ 1 ][ 3 ]
223- @test res2[1 ][4 ] ≈ res_enz[ 1 ][ 4 ]
224- @test res2[1 ][5 ] ≈ res_enz[ 1 ][ 5 ]
225- @test res2[1 ][6 ] ≈ res_enz[ 1 ][ 6 ]
205+ @test res2[1 ][1 ] ≈ 2
206+ @test res2[1 ][2 ] ≈ 4
207+ @test res2[1 ][3 ] ≈ 6
208+ @test res2[1 ][4 ] ≈ 8
209+ @test res2[1 ][5 ] ≈ 10
210+ @test res2[1 ][6 ] ≈ 12
226211end
227212
228213function fn2! (y, x)
@@ -245,34 +230,27 @@ end
245230 dx3_ra = Reactant. to_rarray (dx3)
246231 dx4_ra = Reactant. to_rarray (dx4)
247232
248- dy1 = zeros (2 )
249- dy2 = zeros (2 )
250- dy3 = zeros (2 )
251- dy4 = zeros (2 )
233+ dy1 = ones (2 ) .* 1
234+ dy2 = ones (2 ) .* 2
235+ dy3 = ones (2 ) .* 3
236+ dy4 = ones (2 ) .* 4
252237 dy1_ra = Reactant. to_rarray (dy1)
253238 dy2_ra = Reactant. to_rarray (dy2)
254239 dy3_ra = Reactant. to_rarray (dy3)
255240 dy4_ra = Reactant. to_rarray (dy4)
256241
257- autodiff (
258- ReverseWithPrimal,
259- fn2!,
260- BatchDuplicated (y, (dy1, dy2, dy3, dy4)),
261- BatchDuplicated (x, (dx1, dx2, dx3, dx4)),
262- )
263-
264242 @jit autodiff (
265243 Reverse,
266244 fn2!,
267245 BatchDuplicated (y_ra, (dy1_ra, dy2_ra, dy3_ra, dy4_ra)),
268246 BatchDuplicated (x_ra, (dx1_ra, dx2_ra, dx3_ra, dx4_ra)),
269247 )
270248
271- @test y ≈ y_ra
272- @test dy1 ≈ dy1_ra
273- @test dy2 ≈ dy2_ra
274- @test dy3 ≈ dy3_ra
275- @test dy4 ≈ dy4_ra
249+ @test y_ra ≈ x .^ 2
250+ @test dx1_ra ≈ 2 .* x .* dy1
251+ @test dx2_ra ≈ 2 .* x .* dy2
252+ @test dx3_ra ≈ 2 .* x .* dy3
253+ @test dx4_ra ≈ 2 .* x .* dy4
276254end
277255
278256@testset " make_zero!" begin
@@ -300,7 +278,7 @@ function gradient_fn(x, st)
300278end
301279
302280@testset " seed" begin
303- x = Reactant. to_rarray (rand ( 2 , 2 ))
281+ x = Reactant. to_rarray (Reactant . TestUtils . construct_test_array (Float64, 2 , 2 ))
304282 st = (; rng= Reactant. ReactantRNG ())
305283
306284 @test begin
@@ -344,7 +322,7 @@ function zero_grad2(x)
344322end
345323
346324@testset " ignore_derivatives" begin
347- x = Reactant. to_rarray (rand (Float32, 4 , 4 ))
325+ x = Reactant. to_rarray (Reactant . TestUtils . construct_test_array (Float32, 4 , 4 ))
348326
349327 res1 = @jit Enzyme. gradient (Reverse, simple_grad_without_ignore, x)
350328 @test res1[1 ] ≈ (2 .* Array (x) .+ 4 )
0 commit comments