@@ -31,6 +31,15 @@ JLArrays.allowscalar(false)
3131 @test gradient .(Ref (esitp), idx) ==
3232 collect (gradient .(Ref (jlesitp), idx)) ==
3333 collect (gradient .(Ref (jlesitp), jlidx))
34+
35+ esitp = extrapolate (sitp, 0.0 )
36+ jlesitp = jl (esitp)
37+ idx = - 1.0 : 0.84 : 41.0
38+ jlidx = jl (collect (idx))
39+ @test esitp .(idx) == collect (jlesitp .(idx)) == collect (jlesitp .(jlidx))
40+ @test gradient .(Ref (esitp), idx) ==
41+ collect (gradient .(Ref (jlesitp), idx)) ==
42+ collect (gradient .(Ref (jlesitp), jlidx))
3443end
3544
3645@testset " 2d GPU Interpolation" begin
6978 @test gradient .(Ref (esitp), idx, idx' ) ==
7079 collect (gradient .(Ref (jlesitp), idx, idx' )) ==
7180 collect (gradient .(Ref (jlesitp), jlidx, jlidx' ))
81+
82+ esitp = extrapolate (sitp, 0.0 )
83+ jlesitp = jl (esitp)
84+ idx = - 1.0 : 0.84 : 41.0
85+ jlidx = jl (collect (idx))
86+ @test esitp .(idx, idx' ) == collect (jlesitp .(idx, idx' )) == collect (jlesitp .(jlidx, jlidx' ))
87+ # gradient for `extrapolation` is currently broken under CUDA
88+ @test gradient .(Ref (esitp), idx, idx' ) ==
89+ collect (gradient .(Ref (jlesitp), idx, idx' )) ==
90+ collect (gradient .(Ref (jlesitp), jlidx, jlidx' ))
7291end
7392
7493@testset " Lanczos on gpu" begin
99118 @test eltype (adapt (Array{Real}, itp)) === Float64
100119 @test eltype (adapt (Array{Float32}, scale (itp, A_x))) === Float32
101120 @test eltype (adapt (Array{Float32}, extrapolate (scale (itp, A_x), Flat ()))) === Float32
121+ @test eltype (adapt (Array{Float32}, extrapolate (scale (itp, A_x), 0.0 ))) === Float32
102122 itp = interpolate ((- 1 : 0.2 : 1 , - 1 : 0.2 : 1 ), randn (11 , 11 ), Gridded (Linear ()))
103123 @test eltype (adapt (Array{Float32}, itp)) === Float32
104124 itp = interpolate ((1.0 : 0.0 , 1. :0. ), randn (0 , 0 ), Gridded (Linear ()))
0 commit comments