Skip to content

Commit f70db12

Browse files
authored
test: remove randomness from testing (#1808)
* test: remove randomness from testing * test: more randomness + remove enzymellvm from testing * test: more * test: more * test: remove enzyme testing with finite diff * test: revert linalg
1 parent e5955c8 commit f70db12

29 files changed

+589
-626
lines changed

src/Reactant.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,9 @@ function set_default_backend(backend::Union{String,XLA.AbstractClient})
351351
return nothing
352352
end
353353

354+
# Not part of the public API. Exclusively for testing purposes.
355+
include("TestUtils.jl")
356+
354357
include("Precompile.jl")
355358

356359
end # module

src/TestUtils.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
module TestUtils
2+
3+
using ..Reactant: Reactant, TracedRArray
4+
using ReactantCore: ReactantCore
5+
using LinearAlgebra: LinearAlgebra
6+
7+
function construct_test_array(::Type{T}, dims::Int...) where {T<:AbstractFloat}
8+
flat_vector = collect(T, 1:prod(dims))
9+
flat_vector ./= prod(dims)
10+
return reshape(flat_vector, dims...)
11+
end
12+
13+
function construct_test_array(::Type{Complex{T}}, dims::Int...) where {T<:AbstractFloat}
14+
flat_vector = collect(T, 1:prod(dims))
15+
flat_vector ./= prod(dims)
16+
return reshape(complex.(flat_vector, flat_vector), dims...)
17+
end
18+
19+
function construct_test_array(::Type{T}, dims::Int...) where {T}
20+
return reshape(collect(T, 1:prod(dims)), dims...)
21+
end
22+
23+
function finite_difference_gradient(
24+
f, x::AbstractArray{T}; epsilon=eps(T)^(3 / 4)
25+
) where {T}
26+
onehot_matrix = Reactant.promote_to(
27+
TracedRArray{Reactant.unwrapped_eltype(T),2}, LinearAlgebra.I(length(x))
28+
)
29+
perturbation = reshape(onehot_matrix .* epsilon, size(x)..., length(x))
30+
f_input = cat(x .+ perturbation, x .- perturbation; dims=ndims(x) + 1)
31+
32+
f_evaluated = mapslices(f, f_input; dims=ntuple(identity, ndims(x)))
33+
return ReactantCore.materialize_traced_array(
34+
reshape(
35+
(f_evaluated[1:length(x)] - f_evaluated[(length(x) + 1):end]) ./ (2 * epsilon),
36+
size(x),
37+
),
38+
)
39+
end
40+
41+
end

test/autodiff.jl

Lines changed: 29 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,6 @@ fwd(Mode, RT, x, y) = Enzyme.autodiff(Mode, square, RT, Duplicated(x, y))
4040
end
4141

4242
@testset "Basic Forward Mode" begin
43-
ores1 = fwd(Forward, Duplicated, ones(3, 2), 3.1 * ones(3, 2))
44-
@test typeof(ores1) == NamedTuple{(Symbol("1"),),Tuple{Array{Float64,2}}}
45-
@test ores1[1] 6.2 * ones(3, 2)
46-
4743
res1 = @jit(
4844
fwd(
4945
Forward,
@@ -54,11 +50,7 @@ end
5450
)
5551

5652
@test res1 isa Tuple{<:ConcreteRArray{Float64,2}}
57-
@test res1[1] ores1[1]
58-
59-
ores1 = fwd(ForwardWithPrimal, Duplicated, ones(3, 2), 3.1 * ones(3, 2))
60-
@test typeof(ores1) ==
61-
NamedTuple{(Symbol("1"), Symbol("2")),Tuple{Array{Float64,2},Array{Float64,2}}}
53+
@test res1[1] fill(6.2, 3, 2)
6254

6355
res1 = @jit(
6456
fwd(
@@ -70,11 +62,8 @@ end
7062
)
7163

7264
@test res1 isa Tuple{<:ConcreteRArray{Float64,2},<:ConcreteRArray{Float64,2}}
73-
@test res1[1] ores1[1]
74-
@test res1[2] ores1[2]
75-
76-
ores1 = fwd(Forward, Const, ones(3, 2), 3.1 * ones(3, 2))
77-
@test typeof(ores1) == Tuple{}
65+
@test res1[1] fill(6.2, 3, 2)
66+
@test res1[2] fill(2.0, 3, 2)
7867

7968
res1 = @jit(
8069
fwd(
@@ -87,9 +76,6 @@ end
8776

8877
@test typeof(res1) == Tuple{}
8978

90-
ores1 = fwd(ForwardWithPrimal, Const, ones(3, 2), 3.1 * ones(3, 2))
91-
@test typeof(ores1) == NamedTuple{(Symbol("1"),),Tuple{Array{Float64,2}}}
92-
9379
res1 = @jit(
9480
fwd(
9581
set_abi(ForwardWithPrimal, Reactant.ReactantABI),
@@ -100,7 +86,7 @@ end
10086
)
10187

10288
@test res1 isa Tuple{<:ConcreteRArray{Float64,2}}
103-
@test res1[1] ores1[1]
89+
@test res1[1] fill(2.0, 3, 2)
10490
end
10591

10692
function gw(z)
@@ -140,7 +126,7 @@ function cached_return(x, stret::StateReturn1)
140126
end
141127

142128
@testset "Cached Return: Issue #416" begin
143-
x = rand(10)
129+
x = Reactant.TestUtils.construct_test_array(Float64, 10)
144130
x_ra = Reactant.to_rarray(x)
145131

146132
stret = StateReturn(nothing)
@@ -187,7 +173,7 @@ end
187173
end
188174

189175
@testset "onehot" begin
190-
x = Reactant.to_rarray(rand(3, 4))
176+
x = Reactant.to_rarray(ones(3, 4))
191177
hlo = @code_hlo optimize = false Enzyme.onehot(x)
192178
@test !contains("stablehlo.constant", repr(hlo))
193179
end
@@ -202,27 +188,26 @@ end
202188
x = reshape(collect(Float32, 1:6), 3, 2)
203189
x_ra = Reactant.to_rarray(x)
204190
res = @jit vector_forward_ad(x_ra)
205-
res_enz = vector_forward_ad(x)
206191

207192
@test x_ra x # See https://github.com/EnzymeAD/Reactant.jl/issues/1733
208-
@test res[1][1] res_enz[1][1]
209-
@test res[1][2] res_enz[1][2]
210-
@test res[1][3] res_enz[1][3]
211-
@test res[1][4] res_enz[1][4]
212-
@test res[1][5] res_enz[1][5]
213-
@test res[1][6] res_enz[1][6]
193+
@test res[1][1] 2
194+
@test res[1][2] 4
195+
@test res[1][3] 6
196+
@test res[1][4] 8
197+
@test res[1][5] 10
198+
@test res[1][6] 12
214199

215200
oh = Enzyme.onehot(x)
216201
oh_stacked = stack(oh)
217202
oh_ra = Reactant.to_rarray(oh_stacked)
218203
res2 = @jit vector_forward_ad2(x_ra, oh_ra)
219204

220-
@test res2[1][1] res_enz[1][1]
221-
@test res2[1][2] res_enz[1][2]
222-
@test res2[1][3] res_enz[1][3]
223-
@test res2[1][4] res_enz[1][4]
224-
@test res2[1][5] res_enz[1][5]
225-
@test res2[1][6] res_enz[1][6]
205+
@test res2[1][1] 2
206+
@test res2[1][2] 4
207+
@test res2[1][3] 6
208+
@test res2[1][4] 8
209+
@test res2[1][5] 10
210+
@test res2[1][6] 12
226211
end
227212

228213
function fn2!(y, x)
@@ -245,34 +230,27 @@ end
245230
dx3_ra = Reactant.to_rarray(dx3)
246231
dx4_ra = Reactant.to_rarray(dx4)
247232

248-
dy1 = zeros(2)
249-
dy2 = zeros(2)
250-
dy3 = zeros(2)
251-
dy4 = zeros(2)
233+
dy1 = ones(2) .* 1
234+
dy2 = ones(2) .* 2
235+
dy3 = ones(2) .* 3
236+
dy4 = ones(2) .* 4
252237
dy1_ra = Reactant.to_rarray(dy1)
253238
dy2_ra = Reactant.to_rarray(dy2)
254239
dy3_ra = Reactant.to_rarray(dy3)
255240
dy4_ra = Reactant.to_rarray(dy4)
256241

257-
autodiff(
258-
ReverseWithPrimal,
259-
fn2!,
260-
BatchDuplicated(y, (dy1, dy2, dy3, dy4)),
261-
BatchDuplicated(x, (dx1, dx2, dx3, dx4)),
262-
)
263-
264242
@jit autodiff(
265243
Reverse,
266244
fn2!,
267245
BatchDuplicated(y_ra, (dy1_ra, dy2_ra, dy3_ra, dy4_ra)),
268246
BatchDuplicated(x_ra, (dx1_ra, dx2_ra, dx3_ra, dx4_ra)),
269247
)
270248

271-
@test y y_ra
272-
@test dy1 dy1_ra
273-
@test dy2 dy2_ra
274-
@test dy3 dy3_ra
275-
@test dy4 dy4_ra
249+
@test y_ra x .^ 2
250+
@test dx1_ra 2 .* x .* dy1
251+
@test dx2_ra 2 .* x .* dy2
252+
@test dx3_ra 2 .* x .* dy3
253+
@test dx4_ra 2 .* x .* dy4
276254
end
277255

278256
@testset "make_zero!" begin
@@ -300,7 +278,7 @@ function gradient_fn(x, st)
300278
end
301279

302280
@testset "seed" begin
303-
x = Reactant.to_rarray(rand(2, 2))
281+
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float64, 2, 2))
304282
st = (; rng=Reactant.ReactantRNG())
305283

306284
@test begin
@@ -344,7 +322,7 @@ function zero_grad2(x)
344322
end
345323

346324
@testset "ignore_derivatives" begin
347-
x = Reactant.to_rarray(rand(Float32, 4, 4))
325+
x = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(Float32, 4, 4))
348326

349327
res1 = @jit Enzyme.gradient(Reverse, simple_grad_without_ignore, x)
350328
@test res1[1] (2 .* Array(x) .+ 4)

0 commit comments

Comments
 (0)