From c9015fa07c73493a4c5a009e9722b20cb493ffa3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 15:08:40 +0200 Subject: [PATCH 01/13] =?UTF-8?q?Fix=20derivative=20of=20`=5Fget=5F=CE=BD`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: willtebbutt --- src/basekernels/matern.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index dcc35f6be..7be417ab2 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -37,8 +37,16 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @functor MaternKernel +# workaround for Zygote +# unclear why it's needed but it is fine since it's stated officially that we don't support differentiation with respect to ν @inline _get_ν(k::MaternKernel) = only(k.ν) -ChainRulesCore.@non_differentiable _get_ν(k) # work-around; should be "NotImplemented" rather than NoTangent +function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} + function _get_ν_pullback(Δ) + dν = ChainRulesCore.@not_implemented("derivatives of `MaternKernel` w.r.t. order `ν` are not implemented.") + return Tangent{T}(ν=dν, metric=NoTangent()) + end + return _get_ν(k), _get_ν_pullback +end @inline function kappa(k::MaternKernel, d::Real) result = _matern(_get_ν(k), d) From c4d23e812ec0910cf4285a03aec9863643d17ec0 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 15:09:33 +0200 Subject: [PATCH 02/13] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 47ee0d086..88c97aed8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.37" +version = "0.10.38" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 08437ccc44fa134b86eb447f59d0bdfe651d6acd Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 15:12:11 +0200 Subject: [PATCH 03/13] Update src/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/basekernels/matern.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 7be417ab2..905b71940 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -42,8 +42,10 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @inline _get_ν(k::MaternKernel) = only(k.ν) function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} function _get_ν_pullback(Δ) - dν = ChainRulesCore.@not_implemented("derivatives of `MaternKernel` w.r.t. order `ν` are not implemented.") - return Tangent{T}(ν=dν, metric=NoTangent()) + dν = ChainRulesCore.@not_implemented( + "derivatives of `MaternKernel` w.r.t. order `ν` are not implemented." + ) + return Tangent{T}(; ν=dν, metric=NoTangent()) end return _get_ν(k), _get_ν_pullback end From 8e4ad8f917bca215ac0b99530d4cadd6bef38885 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 15:52:02 +0200 Subject: [PATCH 04/13] Fix `rrule` --- src/basekernels/matern.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 905b71940..425da0ebc 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -45,7 +45,7 @@ function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} dν = ChainRulesCore.@not_implemented( "derivatives of `MaternKernel` w.r.t. order `ν` are not implemented." ) - return Tangent{T}(; ν=dν, metric=NoTangent()) + return NoTangent(), Tangent{T}(; ν=dν, metric=NoTangent()) end return _get_ν(k), _get_ν_pullback end From 9bb7beb0f35204f3fba97b41dfaf0b74ff9446d4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 15:56:06 +0200 Subject: [PATCH 05/13] Add test --- test/basekernels/matern.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 025cb141b..40111fdac 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -18,6 +18,9 @@ @test metric(k2) isa WeightedEuclidean @test k2(v1, v2) ≈ k(v1, v2) + # Test custom `rrule` (Zygote workaround). + test_rrule(KernelFunctions._get_ν, MaternKernel(; nu=rand())) + # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(() -> MaternKernel(; nu=ν)) From 6a467f9119c1ba83834dbca163ac984c50efadd1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 16:15:27 +0200 Subject: [PATCH 06/13] Load ChainRulesTestUtils --- test/Project.toml | 2 ++ test/runtests.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index d71c3dfc0..0f476ded2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -19,6 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AxisArrays = "0.4.3" +ChainRulesTestUtils = "1.7" Compat = "3" Distances = "0.10" Documenter = "0.25, 0.26, 0.27" diff --git a/test/runtests.jl b/test/runtests.jl index a1f5c395a..8f58f0e40 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using KernelFunctions using AxisArrays +using ChainRulesTestUtils using Distances using Documenter using Functors: functor From b95a84a881433e8633f5c28bcb285d4e543f78b9 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 16:43:14 +0200 Subject: [PATCH 07/13] Try to specify cotangent --- test/basekernels/matern.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 40111fdac..4ce120522 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -19,7 +19,8 @@ @test k2(v1, v2) ≈ k(v1, v2) # Test custom `rrule` (Zygote workaround). - test_rrule(KernelFunctions._get_ν, MaternKernel(; nu=rand())) + k = MaternKernel(; ν=rand()) + test_rrule(KernelFunctions._get_ν, k ⊢ Tangent{typeof(k)}(; ν=randn(), metric=NoTangent())) # Standardised tests. TestUtils.test_interface(k, Float64) From c747f706ae725723a7463c1d4d911eec5faa8a4c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 16:44:50 +0200 Subject: [PATCH 08/13] Update test/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/basekernels/matern.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 4ce120522..0b568a812 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -20,7 +20,9 @@ # Test custom `rrule` (Zygote workaround). k = MaternKernel(; ν=rand()) - test_rrule(KernelFunctions._get_ν, k ⊢ Tangent{typeof(k)}(; ν=randn(), metric=NoTangent())) + test_rrule( + KernelFunctions._get_ν, k ⊢ Tangent{typeof(k)}(; ν=randn(), metric=NoTangent()) + ) # Standardised tests. TestUtils.test_interface(k, Float64) From aadcfa86c11e32c11fd86346a88d9be5f386f0af Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 17:01:14 +0200 Subject: [PATCH 09/13] Update matern.jl --- test/basekernels/matern.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 0b568a812..6d8b38bba 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -21,7 +21,7 @@ # Test custom `rrule` (Zygote workaround). k = MaternKernel(; ν=rand()) test_rrule( - KernelFunctions._get_ν, k ⊢ Tangent{typeof(k)}(; ν=randn(), metric=NoTangent()) + KernelFunctions._get_ν, k ⊢ ChainRulesTestUtils.Tangent{typeof(k)}(; ν=randn(), metric=ChainRulesTestUtils.NoTangent()) ) # Standardised tests. From d87cc841559940fa58f25f3e541836fa16f1f276 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 13 Apr 2022 17:03:04 +0200 Subject: [PATCH 10/13] Update test/basekernels/matern.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/basekernels/matern.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index 6d8b38bba..abeb8404a 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -21,7 +21,10 @@ # Test custom `rrule` (Zygote workaround). k = MaternKernel(; ν=rand()) test_rrule( - KernelFunctions._get_ν, k ⊢ ChainRulesTestUtils.Tangent{typeof(k)}(; ν=randn(), metric=ChainRulesTestUtils.NoTangent()) + KernelFunctions._get_ν, + k ⊢ ChainRulesTestUtils.Tangent{typeof(k)}(; + ν=randn(), metric=ChainRulesTestUtils.NoTangent() + ), ) # Standardised tests. From e7f658b8afff1a80638a7a7a88efde95916cf179 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 14 Apr 2022 01:57:17 +0200 Subject: [PATCH 11/13] Remove workaround --- src/basekernels/matern.jl | 15 +-------------- test/Project.toml | 4 +--- test/basekernels/matern.jl | 9 --------- test/runtests.jl | 1 - 4 files changed, 2 insertions(+), 27 deletions(-) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index 425da0ebc..c050cad0d 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -37,21 +37,8 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, @functor MaternKernel -# workaround for Zygote -# unclear why it's needed but it is fine since it's stated officially that we don't support differentiation with respect to ν -@inline _get_ν(k::MaternKernel) = only(k.ν) -function ChainRulesCore.rrule(::typeof(_get_ν), k::T) where {T<:MaternKernel} - function _get_ν_pullback(Δ) - dν = ChainRulesCore.@not_implemented( - "derivatives of `MaternKernel` w.r.t. order `ν` are not implemented." - ) - return NoTangent(), Tangent{T}(; ν=dν, metric=NoTangent()) - end - return _get_ν(k), _get_ν_pullback -end - @inline function kappa(k::MaternKernel, d::Real) - result = _matern(_get_ν(k), d) + result = _matern(only(k.ν), d) return ifelse(iszero(d), one(result), result) end diff --git a/test/Project.toml b/test/Project.toml index 0f476ded2..8860a9fb2 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,5 @@ [deps] AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -20,7 +19,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AxisArrays = "0.4.3" -ChainRulesTestUtils = "1.7" Compat = "3" Distances = "0.10" Documenter = "0.25, 0.26, 0.27" @@ -32,4 +30,4 @@ LogExpFunctions = "0.2, 0.3" PDMats = "0.9, 0.10, 0.11" ReverseDiff = "1.2" SpecialFunctions = "0.10, 1, 2" -Zygote = "0.4, 0.5, 0.6" +Zygote = "0.6.38" diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl index abeb8404a..025cb141b 100644 --- a/test/basekernels/matern.jl +++ b/test/basekernels/matern.jl @@ -18,15 +18,6 @@ @test metric(k2) isa WeightedEuclidean @test k2(v1, v2) ≈ k(v1, v2) - # Test custom `rrule` (Zygote workaround). - k = MaternKernel(; ν=rand()) - test_rrule( - KernelFunctions._get_ν, - k ⊢ ChainRulesTestUtils.Tangent{typeof(k)}(; - ν=randn(), metric=ChainRulesTestUtils.NoTangent() - ), - ) - # Standardised tests. TestUtils.test_interface(k, Float64) test_ADs(() -> MaternKernel(; nu=ν)) diff --git a/test/runtests.jl b/test/runtests.jl index 8f58f0e40..a1f5c395a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,5 @@ using KernelFunctions using AxisArrays -using ChainRulesTestUtils using Distances using Documenter using Functors: functor From 84733860ce67012f0c7413b1b05511737dcf54f0 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 19 Apr 2022 18:54:08 +0200 Subject: [PATCH 12/13] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 88c97aed8..48646036b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.38" +version = "0.10.39" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 4c45fe8f87da9afc70d7847294c9261064af5a19 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 19 Apr 2022 19:53:13 +0200 Subject: [PATCH 13/13] Fix test_interface --- src/TestUtils.jl | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/src/TestUtils.jl b/src/TestUtils.jl index 26b4fe467..b5de47f33 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -6,20 +6,14 @@ using KernelFunctions using Random using Test -# default tolerance values for test_interface: -const __ATOL = sqrt(eps(Float64)) -const __RTOL = sqrt(eps(Float64)) -# ≈ 1.5e-8; chosen for no particular reason other than because it seems to -# satisfy our own test cases within KernelFunctions.jl - """ test_interface( k::Kernel, x0::AbstractVector, x1::AbstractVector, x2::AbstractVector; - atol=__ATOL, - rtol=__RTOL, + rtol=1e-6, + atol=rtol, ) Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`. @@ -29,22 +23,14 @@ be of different lengths. These tests are intended to pick up on really substantial issues with a kernel implementation (e.g. substantial asymmetry in the kernel matrix, large negative eigenvalues), rather than to test the numerics in detail, which can be kernel-specific. -The default value of `__ATOL` and `__RTOL` is `sqrt(eps(Float64)) ≈ 1.5e-8`, which satisfied -this intention in the cases tested within KernelFunctions.jl itself. - - test_interface([rng::AbstractRNG], k::Kernel, T::Type{<:Real}; atol=__ATOL, rtol=__RTOL) - -`test_interface` offers automated test data generation for kernels whose inputs are reals. -This will run the tests for `Vector{T}`, `Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`. -For other input vector types, please provide the data manually. """ function test_interface( k::Kernel, x0::AbstractVector, x1::AbstractVector, x2::AbstractVector; - atol=__ATOL, - rtol=__RTOL, + rtol=1e-6, + atol=rtol, ) # Ensure that we have the required inputs. @assert length(x0) == length(x1) @@ -160,7 +146,16 @@ function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...) return test_interface(Random.GLOBAL_RNG, k, T; kwargs...) end -function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...) +""" + test_interface([rng::AbstractRNG], k::Kernel, ::Type{T}; kwargs...) where {T<:Real} + +Run the [`test_interface`](@ref) tests for randomly generated inputs of types `Vector{T}`, `Vector{Vector{T}}`, `ColVecs{T}`, and `RowVecs{T}`. + +For other input types, please provide the data manually. + +The keyword arguments are forwarded to the invocations of [`test_interface`](@ref) with the randomly generated inputs. +""" +function test_interface(rng::AbstractRNG, k::Kernel, ::Type{T}; kwargs...) where {T<:Real} @testset "Vector{$T}" begin test_interface(rng, k, Vector{T}; kwargs...) end