Skip to content

Commit 85d3ccd

Browse files
authored
Gpu support for FilledExtrapolation (#541)
1 parent 1fd5768 commit 85d3ccd

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

src/extrapolation/filled.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mutable struct FilledExtrapolation{T,N,ITP<:AbstractInterpolation,IT,FT} <: AbstractExtrapolation{T,N,ITP,IT}
1+
struct FilledExtrapolation{T,N,ITP<:AbstractInterpolation,IT,FT} <: AbstractExtrapolation{T,N,ITP,IT}
22
itp::ITP
33
fillvalue::FT
44
end

src/gpu_support.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ function adapt_structure(to, itp::Extrapolation{T,N}) where {T,N}
4141
Extrapolation{eltype(itp′),N,typeof(itp′),itptype(itp),typeof(et)}(itp′, et)
4242
end
4343

44+
function adapt_structure(to, itp::FilledExtrapolation{T,N}) where {T,N}
45+
fillvalue = itp.fillvalue
46+
itp′ = adapt(to, itp.itp)
47+
FilledExtrapolation{eltype(itp′),N,typeof(itp′),itptype(itp),typeof(fillvalue)}(itp′, fillvalue)
48+
end
49+
4450
import Base.Broadcast: broadcasted, BroadcastStyle
4551
using Base.Broadcast: broadcastable, combine_styles, AbstractArrayStyle
4652
function broadcasted(itp::AbstractInterpolation, args...)
@@ -58,6 +64,7 @@ Some array wrappers, like `OffsetArray`, should be skipped.
5864
"""
5965
root_storage_type(::Type{T}) where {T<:AbstractInterpolation} = Array{eltype(T),ndims(T)} # fallback to `Array` by default.
6066
root_storage_type(::Type{T}) where {T<:Extrapolation} = root_storage_type(fieldtype(T, 1))
67+
root_storage_type(::Type{T}) where {T<:FilledExtrapolation} = root_storage_type(fieldtype(T, 1))
6168
root_storage_type(::Type{T}) where {T<:ScaledInterpolation} = root_storage_type(fieldtype(T, 1))
6269
root_storage_type(::Type{T}) where {T<:BSplineInterpolation} = root_storage_type(fieldtype(T, 1))
6370
root_storage_type(::Type{T}) where {T<:LanczosInterpolation} = root_storage_type(fieldtype(T, 1))

test/gpu_support.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@ JLArrays.allowscalar(false)
3131
@test gradient.(Ref(esitp), idx) ==
3232
collect(gradient.(Ref(jlesitp), idx)) ==
3333
collect(gradient.(Ref(jlesitp), jlidx))
34+
35+
esitp = extrapolate(sitp, 0.0)
36+
jlesitp = jl(esitp)
37+
idx = -1.0:0.84:41.0
38+
jlidx = jl(collect(idx))
39+
@test esitp.(idx) == collect(jlesitp.(idx)) == collect(jlesitp.(jlidx))
40+
@test gradient.(Ref(esitp), idx) ==
41+
collect(gradient.(Ref(jlesitp), idx)) ==
42+
collect(gradient.(Ref(jlesitp), jlidx))
3443
end
3544

3645
@testset "2d GPU Interpolation" begin
@@ -69,6 +78,16 @@ end
6978
@test gradient.(Ref(esitp), idx, idx') ==
7079
collect(gradient.(Ref(jlesitp), idx, idx')) ==
7180
collect(gradient.(Ref(jlesitp), jlidx, jlidx'))
81+
82+
esitp = extrapolate(sitp, 0.0)
83+
jlesitp = jl(esitp)
84+
idx = -1.0:0.84:41.0
85+
jlidx = jl(collect(idx))
86+
@test esitp.(idx, idx') == collect(jlesitp.(idx, idx')) == collect(jlesitp.(jlidx, jlidx'))
87+
# gradient for `extrapolation` is currently broken under CUDA
88+
@test gradient.(Ref(esitp), idx, idx') ==
89+
collect(gradient.(Ref(jlesitp), idx, idx')) ==
90+
collect(gradient.(Ref(jlesitp), jlidx, jlidx'))
7291
end
7392

7493
@testset "Lanczos on gpu" begin
@@ -99,6 +118,7 @@ end
99118
@test eltype(adapt(Array{Real}, itp)) === Float64
100119
@test eltype(adapt(Array{Float32}, scale(itp, A_x))) === Float32
101120
@test eltype(adapt(Array{Float32}, extrapolate(scale(itp, A_x), Flat()))) === Float32
121+
@test eltype(adapt(Array{Float32}, extrapolate(scale(itp, A_x), 0.0))) === Float32
102122
itp = interpolate((-1:0.2:1, -1:0.2:1), randn(11, 11), Gridded(Linear()))
103123
@test eltype(adapt(Array{Float32}, itp)) === Float32
104124
itp = interpolate((1.0:0.0, 1.:0.), randn(0, 0), Gridded(Linear()))

0 commit comments

Comments
 (0)