Skip to content

Commit b47eddc

Browse files
authored
Merge pull request #568 from N5N3/cbfix
Some recursion tuning to allow more eager inference.
2 parents 43fe00f + 21f8b76 commit b47eddc

File tree

5 files changed

+43
-21
lines changed

5 files changed

+43
-21
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Interpolations"
22
uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
3-
version = "0.15.0"
3+
version = "0.15.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Interpolations.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,25 +278,34 @@ struct InterpGetindex{N,A<:AbstractArray{<:Any,N}}
278278
InterpGetindex(A::AbstractArray) = new{ndims(A),typeof(A)}(A)
279279
end
280280
@inline Base.getindex(A::InterpGetindex{N}, I::Vararg{Union{Int,WeightedIndex},N}) where {N} =
281-
interp_getindex(A.coeffs, ntuple(_ -> 0, Val(N)), map(indexflag, I)...)
282-
indexflag(I::Int) = I
283-
@inline indexflag(I::WeightedIndex) = indextuple(I), weights(I)
281+
interp_getindex(A.coeffs, ntuple(zero, Val(N)), map(indexflag, I)...)
282+
@inline indexflag(I) = indextuple(I), weights(I)
283+
284+
# Direct recursion would allow more eager inference before julia 1.11.
285+
# Normalize all index into the same format.
286+
struct One end # Singleton for express weights of no-interp dims
287+
indextuple(I::Int) = (I,)
288+
weights(::Int) = (One(),)
289+
290+
struct Zero end # Singleton for dim expansion termination
284291

285292
# A recursion-based `interp_getindex`, which follows a "move processed indexes to the back" strategy
286293
# `I` contains the processed index, and (wi1, wis...) contains the yet-to-be-processed indexes
287-
# Here we meet a no-interp dim, just append the index to `I`'s end.
288-
@inline interp_getindex(A, I, wi1::Int, wis...) =
289-
interp_getindex(A, (Base.tail(I)..., wi1), wis...)
290294
# Here we handle the expansion of a single dimension.
291-
@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any,Vararg{Any,N}}}, wis...) where {N} =
292-
wi1[2][end] * interp_getindex(A, (Base.tail(I)..., wi1[1][end]), wis...) +
293-
interp_getindex(A, I, map(Base.front, wi1), wis...)
294-
@inline interp_getindex(A, I, wi1::NTuple{2,Tuple{Any}}, wis...) =
295-
wi1[2][1] * interp_getindex(A, (Base.tail(I)..., wi1[1][1]), wis...)
295+
@inline function interp_getindex(A, I, (is, ws)::NTuple{2,Tuple}, wis...)
296+
itped1 = interp_getindex(A, (Base.tail(I)..., is[end]), wis...)
297+
witped = interp_getindex(A, I, (Base.front(is), Base.front(ws)), wis...)
298+
_weight_itp(ws[end], itped1, witped)
299+
end
300+
interp_getindex(_, _, ::NTuple{2,Tuple{}}, ::Vararg) = Zero()
296301
# Termination
297302
@inline interp_getindex(A::AbstractArray{T,N}, I::Dims{N}) where {T,N} =
298303
@inbounds A[I...] # all bounds-checks have already happened
299304

305+
_weight_itp(w, i, wir) = w * i + wir
306+
_weight_itp(::One, i, ::Zero) = i
307+
_weight_itp(w, i, ::Zero) = w * i
308+
300309
"""
301310
w = value_weights(degree, δx)
302311

src/b-splines/indexing.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,12 @@ function weightedindexes(parts::Vararg{Union{Int,GradParts},N}) where N
102102
slot_substitute(parts, map(positions, parts), map(valuecoefs, parts), map(gradcoefs, parts))
103103
end
104104

105-
# Skip over NoInterp dimensions
106-
slot_substitute(kind::Tuple{Int,Vararg{Any}}, p, v, g) = slot_substitute(Base.tail(kind), p, v, g)
107105
# Substitute the dth dimension's gradient coefs for the remaining coefs
108-
slot_substitute(kind, p, v, g) = (map(maybe_weightedindex, p, substitute_ruled(v, kind, g)), slot_substitute(Base.tail(kind), p, v, g)...)
106+
function slot_substitute(kind, p, v, g)
107+
rest = slot_substitute(Base.tail(kind), p, v, g)
108+
kind[1] isa Int && return rest # Skip over NoInterp dimensions
109+
(map(maybe_weightedindex, p, substitute_ruled(v, kind, g)), rest...)
110+
end
109111
# Termination
110112
slot_substitute(kind::Tuple{}, p, v, g) = ()
111113

@@ -132,15 +134,14 @@ function _column(kind1::K, kind2::K, p, v, g, h) where {K<:Tuple}
132134
ss = substitute_ruled(v, kind1, h)
133135
(map(maybe_weightedindex, p, ss), _column(Base.tail(kind1), kind2, p, v, g, h)...)
134136
end
137+
_column(kind1::K, kind2::K, p, v, g, h) where {K<:Tuple{Int,Vararg}} = () # Skip over NoInterp dimensions
135138
function _column(kind1::Tuple, kind2::Tuple, p, v, g, h)
139+
rest = _column(Base.tail(kind1), kind2, p, v, g, h)
140+
kind1[1] isa Int && return rest # Skip over NoInterp dimensions
136141
ss = substitute_ruled(substitute_ruled(v, kind1, g), kind2, g)
137-
(map(maybe_weightedindex, p, ss), _column(Base.tail(kind1), kind2, p, v, g, h)...)
142+
(map(maybe_weightedindex, p, ss), rest...)
138143
end
139144
_column(::Tuple{}, ::Tuple, p, v, g, h) = ()
140-
# Skip over NoInterp dimensions
141-
slot_substitute(kind::Tuple{Int,Vararg{Any}}, p, v, g, h) = slot_substitute(Base.tail(kind), p, v, g, h)
142-
_column(kind1::Tuple{Int,Vararg{Any}}, kind2::Tuple, p, v, g, h) =
143-
_column(Base.tail(kind1), kind2, p, v, g, h)
144145

145146
weightedindex_parts(fs::F, itpflag::BSpline, ax, x) where F =
146147
weightedindex_parts(fs, degree(itpflag), ax, x)

test/issues/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ using Interpolations, Test, ForwardDiff
156156
end
157157
@testset "issue 469" begin
158158
# We have different inference result on different version.
159-
max_dim = VERSION < v"1.3" ? 3 : isdefined(Base, :Any32) ? 7 : 5
159+
max_dim = isdefined(Base, :Any32) ? 7 : 5
160160
for dims in 3:max_dim
161161
A = zeros(Float64, ntuple(_ -> 5, dims))
162162
itp = interpolate(A, BSpline(Quadratic(Reflect(OnCell()))))

test/nointerp.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,15 @@
1414
# @test ae[0,1] === NaN
1515
# @test_throws InexactError ae(1.5,2)
1616
end
17+
18+
@testset "Stability of mixtrue with NoInterp and Interp" begin
19+
A = zeros(Float64, 5, 5, 5, 5, 5, 5, 5)
20+
st = BSpline(Quadratic(Reflect(OnCell()))), NoInterp(),
21+
BSpline(Linear()), NoInterp(),
22+
BSpline(Quadratic()), NoInterp(),
23+
BSpline(Quadratic(Reflect(OnCell())))
24+
itp = interpolate(A, st)
25+
@test (@inferred Interpolations.hessian(itp, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) == zeros(4,4)
26+
@test (@inferred Interpolations.gradient(itp, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) == zeros(4)
27+
@test (@inferred itp(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0)) == 0
28+
end

0 commit comments

Comments
 (0)