Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit ebdb9e0

Browse files
committed
coefficient handling for compositions
1 parent 25ae8ee commit ebdb9e0

File tree

2 files changed

+99
-30
lines changed

2 files changed

+99
-30
lines changed

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,36 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}
2626
end
2727
end
2828

29+
# Additive mul! that is used for handling compositions
30+
"""
    mul_add!(x_temp, A::DerivativeOperator{T,N}, M)

Additive counterpart of `mul!`, used when handling compositions of operators.

For every one-dimensional slice of `M` taken along axis `N` (the axis encoded in the type
of `A`), the interior, right-boundary and left-boundary kernels `convolve_interior_add!`,
`convolve_BC_right_add!` and `convolve_BC_left_add!` are applied to the matching slice of
`x_temp`. Per their `_add!` names, those kernels are expected to accumulate into `x_temp`
rather than overwrite it.

# Arguments
- `x_temp`: destination array. It must have as many dimensions as `M` and the same size
  along every axis except axis `N`, where it must be exactly two entries shorter.
- `A`: the derivative operator to apply along axis `N`.
- `M`: source array.

# Throws
- `AssertionError` if `N` is not a valid axis of `M` or if the sizes of `x_temp` and `M`
  are inconsistent.
"""
function mul_add!(
    x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::AbstractArray{T}
) where {T,N}
    nd = ndims(M)

    # Check that the axis of differentiation is in the dimensions of M and x_temp.
    # This runs before N is used as an index, so an invalid axis is reported by the
    # assertion rather than by an unrelated BoundsError.
    @assert 1 <= N <= nd

    # Check that x_temp has correct dimensions: M carries two extra entries along axis N.
    @assert ndims(x_temp) == nd &&
        all(d -> size(x_temp, d) + (d == N ? 2 : 0) == size(M, d), 1:nd)

    # Every axis except N is iterated over; axis N is taken whole via a Colon.
    otherdims = setdiff(1:nd, N)
    idx = Any[first(ax) for ax in axes(M)]
    idx[N] = Colon()

    for I in CartesianIndices(Tuple(axes(M, d) for d in otherdims))
        # Point idx at the current slice (plain loop instead of the internal
        # Base.replace_tuples! helper).
        for (k, d) in enumerate(otherdims)
            idx[d] = I[k]
        end
        dest = view(x_temp, idx...)
        src = view(M, idx...)
        convolve_interior_add!(dest, src, A)
        convolve_BC_right_add!(dest, src, A)
        convolve_BC_left_add!(dest, src, A)
    end
    return nothing
end
58+
2959
# A more efficient mul! implementation for a single, regular-grid, centered difference,
3060
# scalar coefficient DerivativeOperator operating on a 2D or 3D AbstractArray
3161
for MT in [2,3]
@@ -145,11 +175,10 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
145175
sl = L.stencil_length
146176
axis = typeof(L).parameters[2]
147177
offset = convert(Int64,(Wdims[axis] - sl)/2)
178+
coeff = L.coefficients isa Number ? L.coefficients : true
148179
for i in offset+1:Wdims[axis]-offset
149180
idx[axis]=i
150-
151-
W[idx...] += s[i-offset]
152-
181+
W[idx...] += coeff*s[i-offset]
153182
idx[axis] = mid_Wdims[axis]
154183
end
155184
end
@@ -271,7 +300,7 @@ function convolve_interior_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
271300
for i in (1+A.boundary_point_count) : (length(x_temp)-A.boundary_point_count)
272301
xtempi = zero(T)
273302
cur_stencil = eltype(stencil) <: AbstractVector ? stencil[i] : stencil
274-
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : true
303+
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
275304
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
276305
for idx in 1:A.stencil_length
277306
xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
@@ -288,7 +317,7 @@ function convolve_interior_add_range!(x_temp::AbstractVector{T}, x::AbstractVect
288317
for i in [(1+A.boundary_point_count):(A.boundary_point_count+offset); (length(x_temp)-A.boundary_point_count-offset+1):(length(x_temp)-A.boundary_point_count)]
289318
xtempi = zero(T)
290319
cur_stencil = eltype(stencil) <: AbstractVector ? stencil[i] : stencil
291-
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : true
320+
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
292321
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
293322
for idx in 1:A.stencil_length
294323
xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
@@ -301,10 +330,10 @@ function convolve_BC_left_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
301330
stencil = A.low_boundary_coefs
302331
coeff = A.coefficients
303332
for i in 1 : A.boundary_point_count
304-
xtempi = stencil[i][1]*x[1]
305333
cur_stencil = stencil[i]
306-
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : true
334+
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
307335
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
336+
xtempi = cur_coeff*stencil[i][1]*x[1]
308337
for idx in 2:A.boundary_stencil_length
309338
xtempi += cur_coeff * cur_stencil[idx] * x[idx]
310339
end
@@ -316,10 +345,10 @@ function convolve_BC_right_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
316345
stencil = A.high_boundary_coefs
317346
coeff = A.coefficients
318347
for i in 1 : A.boundary_point_count
319-
xtempi = stencil[i][end]*x[end]
320348
cur_stencil = stencil[i]
321-
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : true
349+
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
322350
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
351+
xtempi = cur_coeff*stencil[i][end]*x[end]
323352
for idx in (A.boundary_stencil_length-1):-1:1
324353
xtempi += cur_coeff * cur_stencil[end-idx] * x[end-idx]
325354
end

test/2D_3D_fast_multiplication.jl

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,5 @@
11
using LinearAlgebra, DiffEqOperators, Random, Test, BandedMatrices, SparseArrays
22

3-
# Build a dense N×(N+2) matrix from hard-coded stencil rows (going by the function name,
# a fourth-derivative approximation). Rows 1, 2, N-1 and N use 8-wide one-sided rows; the
# last two are the mirror images of the first two. Rows 3:N-2 use a 7-wide row placed in
# columns i-2:i+4.
function fourth_deriv_approx_stencil(N)
    edge_outer = [3.5, -56/3, 42.5, -54.0, 251/6, -20.0, 5.5, -2/3]
    edge_inner = [2/3, -11/6, 0.0, 31/6, -22/3, 4.5, -4/3, 1/6]
    interior = [-1/6, 2.0, -13/2, 28/3, -13/2, 2.0, -1/6]

    stencil = zeros(N, N + 2)
    stencil[1, 1:8] = edge_outer
    stencil[2, 1:8] = edge_inner
    stencil[N - 1, (N - 5):(N + 2)] = reverse(edge_inner)
    stencil[N, (N - 5):(N + 2)] = reverse(edge_outer)
    for row in 3:(N - 2)
        stencil[row, (row - 2):(row + 4)] = interior
    end
    return stencil
end
14-
15-
# Build a dense N×(N+2) matrix whose row i holds the three-point pattern [1, -2, 1] in
# columns i, i+1 and i+2; every other entry is zero.
function second_derivative_stencil(N)
    stencil = zeros(N, N + 2)
    for row in 1:N
        stencil[row, row] = 1
        stencil[row, row + 1] = -2
        stencil[row, row + 2] = 1
    end
    return stencil
end
23-
243
@testset "2D Multiplication with no boundary points and dx = 1.0" begin
254

265
# Test (Lxx + Lyy)*M, dx = 1.0, no coefficient
@@ -409,3 +388,64 @@ end
409388
@test M_temp ((Lx2*M)[1:N,2:N+1]+(Ly2*M)[2:N+1,1:N]+(Lx3*M)[1:N,2:N+1] +(Ly3*M)[2:N+1,1:N] + (Lx4*M)[1:N,2:N+1] +(Ly4*M)[2:N+1,1:N])
410389

411390
end
391+
392+
# This testset reuses the setup of the previous testset, which has several non-trivial cases,
393+
# and additionally tests coefficient handling. All operators are handled by the
394+
# fast 2D/3D dispatch
395+
# Checks that scalar coefficients attached to operators are honoured when compositions of
# operators are applied with `mul!`. Each `mul!` result is compared, with `≈`, against the
# sum of the individual operator products.
@testset "2D coefficient handling" begin
    dx = 0.1
    dy = 0.25
    N = 100
    M = zeros(N + 2, N + 2)
    M_temp = zeros(N, N + 2)
    for i in 1:(N + 2)
        for j in 1:(N + 2)
            M[i, j] = cos(dx * i) + sin(dy * j)
        end
    end

    # Lx2 has 0 boundary points
    Lx2 = 5.5 * CenteredDifference{1}(2, 2, dx, N)
    # Lx3 has 1 boundary point
    Lx3 = 1.45 * CenteredDifference{1}(3, 3, dx, N)
    # Lx4 has 2 boundary points
    Lx4 = 0.5 * CenteredDifference{1}(4, 4, dx, N)

    # Test a single axis, multiple operators: (Lxx+Lxxxx)*M, dx = 0.1, scalar coefficients
    A = Lx2 + Lx4
    mul!(M_temp, A, M)
    @test M_temp ≈ ((Lx2 * M) + (Lx4 * M))

    # Test a single axis, multiple operators: (Lxx+Lxxx+Lxxxx)*M, dx = 0.1,
    # scalar coefficients
    A += Lx3
    mul!(M_temp, A, M)
    @test M_temp ≈ ((Lx2 * M) + (Lx3 * M) + (Lx4 * M))

    # Ly2 has 0 boundary points
    Ly2 = 8.14 * CenteredDifference{2}(2, 2, dy, N)
    # Ly3 has 1 boundary point
    Ly3 = 2.0 * CenteredDifference{2}(3, 3, dy, N)
    # Ly4 has 2 boundary points
    Ly4 = 4.567 * CenteredDifference{2}(4, 4, dy, N)
    M_temp = zeros(N + 2, N)

    # Test a single axis, multiple operators: (Lyy+Lyyyy)*M, dy = 0.25, scalar coefficients
    A = Ly2 + Ly4
    mul!(M_temp, A, M)
    @test M_temp ≈ ((Ly2 * M) + (Ly4 * M))

    # Test a single axis, multiple operators: (Lyy+Lyyy+Lyyyy)*M, dy = 0.25,
    # scalar coefficients
    A += Ly3
    mul!(M_temp, A, M)
    @test M_temp ≈ ((Ly2 * M) + (Ly3 * M) + (Ly4 * M))

    # Test multiple operators on both axes:
    # (Lxx + Lyy + Lxxx + Lyyy + Lxxxx + Lyyyy)*M, each with a scalar coefficient
    A = Lx2 + Ly2 + Lx3 + Ly3 + Lx4 + Ly4
    M_temp = zeros(N, N)
    mul!(M_temp, A, M)

    # Each single-axis product keeps the two extra entries on the other axis, so it is
    # cropped to the common N×N region before summing.
    @test M_temp ≈ (
        (Lx2 * M)[1:N, 2:(N + 1)] + (Ly2 * M)[2:(N + 1), 1:N] +
        (Lx3 * M)[1:N, 2:(N + 1)] + (Ly3 * M)[2:(N + 1), 1:N] +
        (Lx4 * M)[1:N, 2:(N + 1)] + (Ly4 * M)[2:(N + 1), 1:N]
    )
end

0 commit comments

Comments
 (0)