@@ -69,6 +69,10 @@ function gradient(f, ::Val{:FiniteDiff}, args)
6969 return first (FiniteDifferences. grad (FDM, f, args))
7070end
7171
72+ function compare_gradient (f, :: Val{:FiniteDiff} , args)
73+ @test_nowarn gradient (f, :FiniteDiff , args)
74+ end
75+
7276function compare_gradient (f, AD:: Symbol , args)
7377 grad_AD = gradient (f, AD, args)
7478 grad_FD = gradient (f, :FiniteDiff , args)
@@ -88,7 +92,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))
8892function test_ADs (
8993 kernelfunction, args= nothing ; ADs= [:Zygote , :ForwardDiff , :ReverseDiff ], dims= [3 , 3 ]
9094)
91- test_fd = test_FiniteDiff ( kernelfunction, args, dims)
95+ test_fd = test_AD ( :FiniteDiff , kernelfunction, args, dims)
9296 if ! test_fd. anynonpass
9397 for AD in ADs
9498 test_AD (AD, kernelfunction, args, dims)
@@ -100,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
100104 @inferred f (args... )
101105 @inferred Zygote. _pullback (ctx, f, args... )
102106 out, pb = Zygote. _pullback (ctx, f, args... )
103- @test_throws ErrorException @ inferred pb (out)
107+ @inferred pb (out)
104108end
105109
106110function test_ADs (
@@ -114,70 +118,6 @@ function test_ADs(
114118 end
115119end
116120
117- function test_FiniteDiff (kernelfunction, args= nothing , dims= [3 , 3 ])
118- # Init arguments :
119- k = if args === nothing
120- kernelfunction ()
121- else
122- kernelfunction (args)
123- end
124- rng = MersenneTwister (42 )
125- @testset " FiniteDifferences" begin
126- if k isa SimpleKernel
127- for d in log .([eps (), rand (rng)])
128- @test_nowarn gradient (:FiniteDiff , [d]) do x
129- kappa (k, exp (first (x)))
130- end
131- end
132- end
133- # # Testing Kernel Functions
134- x = rand (rng, dims[1 ])
135- y = rand (rng, dims[1 ])
136- @test_nowarn gradient (:FiniteDiff , x) do x
137- k (x, y)
138- end
139- if ! (args === nothing )
140- @test_nowarn gradient (:FiniteDiff , args) do p
141- kernelfunction (p)(x, y)
142- end
143- end
144- # # Testing Kernel Matrices
145- A = rand (rng, dims... )
146- B = rand (rng, dims... )
147- for dim in 1 : 2
148- @test_nowarn gradient (:FiniteDiff , A) do a
149- testfunction (k, a, dim)
150- end
151- @test_nowarn gradient (:FiniteDiff , A) do a
152- testfunction (k, a, B, dim)
153- end
154- @test_nowarn gradient (:FiniteDiff , B) do b
155- testfunction (k, A, b, dim)
156- end
157- if ! (args === nothing )
158- @test_nowarn gradient (:FiniteDiff , args) do p
159- testfunction (kernelfunction (p), A, B, dim)
160- end
161- end
162-
163- @test_nowarn gradient (:FiniteDiff , A) do a
164- testdiagfunction (k, a, dim)
165- end
166- @test_nowarn gradient (:FiniteDiff , A) do a
167- testdiagfunction (k, a, B, dim)
168- end
169- @test_nowarn gradient (:FiniteDiff , B) do b
170- testdiagfunction (k, A, b, dim)
171- end
172- if args != = nothing
173- @test_nowarn gradient (:FiniteDiff , args) do p
174- testdiagfunction (kernelfunction (p), A, B, dim)
175- end
176- end
177- end
178- end
179- end
180-
181121function test_FiniteDiff (k:: MOKernel , dims= (in= 3 , out= 2 , obs= 3 ))
182122 rng = MersenneTwister (42 )
183123 @testset " FiniteDifferences" begin
@@ -224,68 +164,68 @@ end
224164
225165function test_AD (AD:: Symbol , kernelfunction, args= nothing , dims= [3 , 3 ])
226166 @testset " $(AD) " begin
227- # Test kappa function
228167 k = if args === nothing
229168 kernelfunction ()
230169 else
231170 kernelfunction (args)
232171 end
233172 rng = MersenneTwister (42 )
173+
234174 if k isa SimpleKernel
235- for d in log .([eps (), rand (rng)])
236- compare_gradient (AD, [d]) do x
237- kappa (k, exp (x[1 ]))
175+ @testset " kappa function" begin
176+ for d in log .([eps (), rand (rng)])
177+ compare_gradient (AD, [d]) do x
178+ kappa (k, exp (x[1 ]))
179+ end
238180 end
239181 end
240182 end
241- # Testing kernel evaluations
242- x = rand (rng, dims[1 ])
243- y = rand (rng, dims[1 ])
244- compare_gradient (AD, x) do x
245- k (x, y)
246- end
247- compare_gradient (AD, y) do y
248- k (x, y)
249- end
250- if ! (args === nothing )
251- compare_gradient (AD, args) do p
252- kernelfunction (p)(x, y)
253- end
254- end
255- # Testing kernel matrices
256- A = rand (rng, dims... )
257- B = rand (rng, dims... )
258- for dim in 1 : 2
259- compare_gradient (AD, A) do a
260- testfunction (k, a, dim)
261- end
262- compare_gradient (AD, A) do a
263- testfunction (k, a, B, dim)
183+
184+ @testset " kernel evaluations" begin
185+ x = rand (rng, dims[1 ])
186+ y = rand (rng, dims[1 ])
187+ compare_gradient (AD, x) do x
188+ k (x, y)
264189 end
265- compare_gradient (AD, B ) do b
266- testfunction (k, A, b, dim )
190+ compare_gradient (AD, y ) do y
191+ k (x, y )
267192 end
268193 if ! (args === nothing )
269- compare_gradient (AD, args) do p
270- testfunction (kernelfunction (p), A, dim)
194+ @testset " hyperparameters" begin
195+ compare_gradient (AD, args) do p
196+ kernelfunction (p)(x, y)
197+ end
271198 end
272199 end
200+ end
273201
274- compare_gradient (AD, A) do a
275- testdiagfunction (k, a, dim)
276- end
277- compare_gradient (AD, A) do a
278- testdiagfunction (k, a, B, dim)
279- end
280- compare_gradient (AD, B) do b
281- testdiagfunction (k, A, b, dim)
282- end
283- if args != = nothing
284- compare_gradient (AD, args) do p
285- testdiagfunction (kernelfunction (p), A, dim)
202+ @testset " kernel matrices" begin
203+ A = rand (rng, dims... )
204+ B = rand (rng, dims... )
205+ @testset " $(_testfn) " for _testfn in (testfunction, testdiagfunction)
206+ for dim in 1 : 2
207+ compare_gradient (AD, A) do a
208+ _testfn (k, a, dim)
209+ end
210+ compare_gradient (AD, A) do a
211+ _testfn (k, a, B, dim)
212+ end
213+ compare_gradient (AD, B) do b
214+ _testfn (k, A, b, dim)
215+ end
216+ if ! (args === nothing )
217+ @testset " hyperparameters" begin
218+ compare_gradient (AD, args) do p
219+ _testfn (kernelfunction (p), A, dim)
220+ end
221+ compare_gradient (AD, args) do p
222+ _testfn (kernelfunction (p), A, B, dim)
223+ end
224+ end
225+ end
286226 end
287227 end
288- end
228+ end # kernel matrices
289229 end
290230end
291231
0 commit comments