@@ -26,7 +26,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}
2626 end
2727end
2828
29- # Additive mul! that is used for handling compositions
29+ # Additive mul! fallback that is necessary for handling compositions
3030function mul_add! (x_temp:: AbstractArray{T} , A:: DerivativeOperator{T,N} , M:: AbstractArray{T} ) where {T,N}
3131
3232 # Check that x_temp has correct dimensions
@@ -57,7 +57,7 @@ function mul_add!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::Abstr
5757end
5858
5959# A more efficient mul! implementation for a single, regular-grid, centered difference,
60- # scalar coefficient DerivativeOperator operating on a 2D or 3D AbstractArray
60+ # scalar-coefficient, non-winding DerivativeOperator operating on a 2D or 3D AbstractArray
6161for MT in [2 ,3 ]
6262 @eval begin
6363 function LinearAlgebra. mul! (x_temp:: AbstractArray{T,$MT} , A:: DerivativeOperator{T,N,false,T2,S1,S2,T3} , M:: AbstractArray{T,$MT} ) where
@@ -118,6 +118,8 @@ for MT in [2,3]
118118 end
119119end
120120
121+ # ##########################################
122+
121123function * (A:: DerivativeOperator{T,N} ,M:: AbstractArray{T} ) where {T<: Real ,N}
122124 size_x_temp = [size (M)... ]
123125 size_x_temp[N] -= 2
@@ -141,10 +143,10 @@ function *(c::Number, A::DerivativeOperator{T,N,Wind}) where {T,N,Wind}
141143end
142144
143145
144- # TODO fix syntax error here
146+ # ##########################################
145147
146- # A more efficient mul! implementation for a composition of regular-grid, centered difference
147- # DerivativeOperator operating on a 2D or 3D AbstractArray
148+ # A more efficient mul! implementation for compositions of operators, which may include regular-grid,
149+ # centered-difference, scalar-coefficient, non-winding DerivativeOperators, operating on a 2D or 3D AbstractArray
148150function LinearAlgebra. mul! (x_temp:: AbstractArray{T,2} , A:: AbstractDiffEqCompositeOperator , M:: AbstractArray{T,2} ) where {T}
149151
150152 # opsA operators satisfy conditions for NNlib.conv! call, opsB operators do not
@@ -160,13 +162,12 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
160162
161163 # Check that we can make at least one NNlib.conv! call
162164 if ! isempty (opsA)
163- # TODO replace A.ops with opsA in here
164165 ndimsM = ndims (M)
165166 Wdims = ones (Int64,ndimsM)
166167 pad = zeros (Int64, ndimsM)
167168
168169 # compute dimensions of interior kernel W
169- # Here we still use A.ops since the other dimensions may indicate that
170+ # Here we still use A.ops since operators in opsB may indicate that
170171 # we have more padding to account for
171172 for L in A. ops
172173 axis = typeof (L). parameters[2 ]
@@ -237,7 +238,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
237238 offset_x = 1
238239 end
239240 if length (ops_1) > 0
240- offset_y = 1
241+ offset_y = 1
241242 end
242243
243244 # convolve boundaries and unaccounted for interior in axis 1
@@ -265,23 +266,21 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
265266 # convolve boundaries and unaccounted for interior in axis 2
266267 if length (ops_2) > 0
267268 for i in 1 : size (x_temp)[1 ]
268- # in the case of no axis 1 operators, we need to over x_temp
269+ # in the case of no axis 1 operators, we need to overwrite x_temp
269270 if length (ops_1) == 0
270271 convolve_BC_left! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
271272 convolve_BC_right! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
272273 if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
273274 convolve_interior! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
274275 end
275- # scale by dx
276- # fix here as well
276+
277277 else
278278 convolve_BC_left_add! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
279279 convolve_BC_right_add! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
280280 if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
281281 convolve_interior_add! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
282282 end
283- # scale by dx
284- # fix here as well
283+
285284 end
286285 for Lidx in ops_2
287286 if Lidx != ops_2_max_bpc_idx[1 ]
@@ -297,7 +296,9 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
297296 end
298297 end
299298 end
300- # operating_dims
299+
300+ # Here we compute mul! (additively) for every operator in opsB
301+
301302 operating_dims = zeros (Int64,2 )
302303 # need to consider all dimensions and operators to determine the truncation
303304 # of M to x_temp
@@ -328,7 +329,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
328329 end
329330 end
330331
331- # Call everything A.ops using fallback
332+ # The case where we call everything in A.ops using the fallback mul!
332333 else
333334 # operating_dims
334335 operating_dims = zeros (Int64,2 )
@@ -376,70 +377,3 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
376377 end
377378 end
378379end
379-
380-
381- # Implementations of additive convolutions, need to remove redundancy later since we are
382- # only making these calls when grids are regular, non-winding, and coefficients among indices are equivalent
383- function convolve_interior_add! (x_temp:: AbstractVector{T} , x:: AbstractVector{T} , A:: DerivativeOperator ) where {T<: Real }
384- @assert length (x_temp)+ 2 == length (x)
385- stencil = A. stencil_coefs
386- coeff = A. coefficients
387- mid = div (A. stencil_length,2 )
388- for i in (1 + A. boundary_point_count) : (length (x_temp)- A. boundary_point_count)
389- xtempi = zero (T)
390- cur_stencil = eltype (stencil) <: AbstractVector ? stencil[i- A. boundary_point_count] : stencil
391- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
392- cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
393- for idx in 1 : A. stencil_length
394- xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
395- end
396- x_temp[i] += xtempi
397- end
398- end
399-
400- function convolve_interior_add_range! (x_temp:: AbstractVector{T} , x:: AbstractVector{T} , A:: DerivativeOperator , offset:: Int ) where {T<: Real }
401- @assert length (x_temp)+ 2 == length (x)
402- stencil = A. stencil_coefs
403- coeff = A. coefficients
404- mid = div (A. stencil_length,2 )
405- for i in [(1 + A. boundary_point_count): (A. boundary_point_count+ offset); (length (x_temp)- A. boundary_point_count- offset+ 1 ): (length (x_temp)- A. boundary_point_count)]
406- xtempi = zero (T)
407- cur_stencil = eltype (stencil) <: AbstractVector ? stencil[i] : stencil
408- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
409- cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
410- for idx in 1 : A. stencil_length
411- xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
412- end
413- x_temp[i] += xtempi
414- end
415- end
416-
417- function convolve_BC_left_add! (x_temp:: AbstractVector{T} , x:: AbstractVector{T} , A:: DerivativeOperator ) where {T<: Real }
418- stencil = A. low_boundary_coefs
419- coeff = A. coefficients
420- for i in 1 : A. boundary_point_count
421- cur_stencil = stencil[i]
422- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
423- cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
424- xtempi = cur_coeff* stencil[i][1 ]* x[1 ]
425- for idx in 2 : A. boundary_stencil_length
426- xtempi += cur_coeff * cur_stencil[idx] * x[idx]
427- end
428- x_temp[i] += xtempi
429- end
430- end
431-
432- function convolve_BC_right_add! (x_temp:: AbstractVector{T} , x:: AbstractVector{T} , A:: DerivativeOperator ) where {T<: Real }
433- stencil = A. high_boundary_coefs
434- coeff = A. coefficients
435- for i in 1 : A. boundary_point_count
436- cur_stencil = stencil[i]
437- cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
438- cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
439- xtempi = cur_coeff* stencil[i][end ]* x[end ]
440- for idx in (A. boundary_stencil_length- 1 ): - 1 : 1
441- xtempi += cur_coeff * cur_stencil[end - idx] * x[end - idx]
442- end
443- x_temp[end - A. boundary_point_count+ i] += xtempi
444- end
445- end
0 commit comments