@@ -26,6 +26,36 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}
2626 end
2727end
2828
29+ # Additive mul! that is used for handling compositions
30+ function mul_add! (x_temp:: AbstractArray{T} , A:: DerivativeOperator{T,N} , M:: AbstractArray{T} ) where {T,N}
31+
32+ # Check that x_temp has correct dimensions
33+ v = zeros (ndims (x_temp))
34+ v[N] = 2
35+ @assert [size (x_temp)... ]+ v == [size (M)... ]
36+
37+ # Check that axis of differentiation is in the dimensions of M and x_temp
38+ ndimsM = ndims (M)
39+ @assert N <= ndimsM
40+
41+ dimsM = [axes (M)... ]
42+ alldims = [1 : ndims (M);]
43+ otherdims = setdiff (alldims, N)
44+
45+ idx = Any[first (ind) for ind in axes (M)]
46+ itershape = tuple (dimsM[otherdims]. .. )
47+ nidx = length (otherdims)
48+ indices = Iterators. drop (CartesianIndices (itershape), 0 )
49+
50+ setindex! (idx, :, N)
51+ for I in indices
52+ Base. replace_tuples! (nidx, idx, idx, otherdims, I)
53+ convolve_interior_add! (view (x_temp, idx... ), view (M, idx... ), A)
54+ convolve_BC_right_add! (view (x_temp, idx... ), view (M, idx... ), A)
55+ convolve_BC_left_add! (view (x_temp, idx... ), view (M, idx... ), A)
56+ end
57+ end
58+
2959# A more efficient mul! implementation for a single, regular-grid, centered difference,
3060# scalar coefficient DerivativeOperator operating on a 2D or 3D AbstractArray
3161for MT in [2 ,3 ]
@@ -145,11 +175,10 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
145175 sl = L. stencil_length
146176 axis = typeof (L). parameters[2 ]
147177 offset = convert (Int64,(Wdims[axis] - sl)/ 2 )
178+ coeff = L. coefficients isa Number ? L. coefficients : true
148179 for i in offset+ 1 : Wdims[axis]- offset
149180 idx[axis]= i
150-
151- W[idx... ] += s[i- offset]
152-
181+ W[idx... ] += coeff* s[i- offset]
153182 idx[axis] = mid_Wdims[axis]
154183 end
155184 end
@@ -271,7 +300,7 @@ function convolve_interior_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
271300 for i in (1 + A. boundary_point_count) : (length (x_temp)- A. boundary_point_count)
272301 xtempi = zero (T)
273302 cur_stencil = eltype (stencil) <: AbstractVector ? stencil[i] : stencil
274- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : true
303+ cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
275304 cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
276305 for idx in 1 : A. stencil_length
277306 xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
@@ -288,7 +317,7 @@ function convolve_interior_add_range!(x_temp::AbstractVector{T}, x::AbstractVect
288317 for i in [(1 + A. boundary_point_count): (A. boundary_point_count+ offset); (length (x_temp)- A. boundary_point_count- offset+ 1 ): (length (x_temp)- A. boundary_point_count)]
289318 xtempi = zero (T)
290319 cur_stencil = eltype (stencil) <: AbstractVector ? stencil[i] : stencil
291- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : true
320+ cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
292321 cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
293322 for idx in 1 : A. stencil_length
294323 xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
@@ -301,10 +330,10 @@ function convolve_BC_left_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
301330 stencil = A. low_boundary_coefs
302331 coeff = A. coefficients
303332 for i in 1 : A. boundary_point_count
304- xtempi = stencil[i][1 ]* x[1 ]
305333 cur_stencil = stencil[i]
306- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : true
334+ cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
307335 cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
336+ xtempi = cur_coeff* stencil[i][1 ]* x[1 ]
308337 for idx in 2 : A. boundary_stencil_length
309338 xtempi += cur_coeff * cur_stencil[idx] * x[idx]
310339 end
@@ -316,10 +345,10 @@ function convolve_BC_right_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
316345 stencil = A. high_boundary_coefs
317346 coeff = A. coefficients
318347 for i in 1 : A. boundary_point_count
319- xtempi = stencil[i][end ]* x[end ]
320348 cur_stencil = stencil[i]
321- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : true
349+ cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
322350 cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
351+ xtempi = cur_coeff* stencil[i][end ]* x[end ]
323352 for idx in (A. boundary_stencil_length- 1 ): - 1 : 1
324353 xtempi += cur_coeff * cur_stencil[end - idx] * x[end - idx]
325354 end
0 commit comments