@@ -145,147 +145,165 @@ end
145145# DerivativeOperator operating on a 2D or 3D AbstractArray
146146function LinearAlgebra. mul! (x_temp:: AbstractArray{T,2} , A:: AbstractDiffEqCompositeOperator , M:: AbstractArray{T,2} ) where {T}
147147
148- # Check that composite operator satisfies: regular-grid, centered difference:
149- # for L in A.ops
150- # if L ... (does not satisfy conditions)
151- # return (call fall back for multiplication of composite operators)
152- # end
153- # end
154-
155- ndimsM = ndims (M)
156- Wdims = ones (Int64,ndimsM)
157- pad = zeros (Int64, ndimsM)
158-
159- # compute dimensions of interior kernel W
148+ # opsA operators satisfy conditions for NNlib.conv! call, opsB operators do not
149+ opsA = DerivativeOperator[]
150+ opsB = DerivativeOperator[]
160151 for L in A. ops
161- axis = typeof (L). parameters[2 ]
162- @assert axis <= ndimsM
163- Wdims[axis] = max (Wdims[axis],L. stencil_length)
164- pad[axis] = max (pad[axis], L. boundary_point_count)
152+ if (L. coefficients isa Number || L. coefficients === nothing ) && use_winding (L) == false && L. dx isa Number
153+ push! (opsA, L)
154+ else
155+ push! (opsB,L)
156+ end
165157 end
166158
167- # create zero-valued kernel
168- W = zeros (T, Wdims... )
169- mid_Wdims = div .(Wdims,2 ).+ 1
170- idx = div .(Wdims,2 ).+ 1
159+ # Check that we can make at least one NNlib.conv! call
160+ if ! isempty (opsA)
161+ # TODO replace A.ops with opsA in here
162+ ndimsM = ndims (M)
163+ Wdims = ones (Int64,ndimsM)
164+ pad = zeros (Int64, ndimsM)
165+
166+ # compute dimensions of interior kernel W
167+ for L in A. ops
168+ axis = typeof (L). parameters[2 ]
169+ @assert axis <= ndimsM
170+ Wdims[axis] = max (Wdims[axis],L. stencil_length)
171+ pad[axis] = max (pad[axis], L. boundary_point_count)
172+ end
171173
172- # add to kernel each stencil
173- for L in A. ops
174- s = L. stencil_coefs
175- sl = L. stencil_length
176- axis = typeof (L). parameters[2 ]
177- offset = convert (Int64,(Wdims[axis] - sl)/ 2 )
178- coeff = L. coefficients isa Number ? L. coefficients : true
179- for i in offset+ 1 : Wdims[axis]- offset
180- idx[axis]= i
181- W[idx... ] += coeff* s[i- offset]
182- idx[axis] = mid_Wdims[axis]
174+ # create zero-valued kernel
175+ W = zeros (T, Wdims... )
176+ mid_Wdims = div .(Wdims,2 ).+ 1
177+ idx = div .(Wdims,2 ).+ 1
178+
179+ # add to kernel each stencil
180+ for L in A. ops
181+ s = L. stencil_coefs
182+ sl = L. stencil_length
183+ axis = typeof (L). parameters[2 ]
184+ offset = convert (Int64,(Wdims[axis] - sl)/ 2 )
185+ coeff = L. coefficients isa Number ? L. coefficients : true
186+ for i in offset+ 1 : Wdims[axis]- offset
187+ idx[axis]= i
188+ W[idx... ] += coeff* s[i- offset]
189+ idx[axis] = mid_Wdims[axis]
190+ end
183191 end
184- end
185192
186- # Reshape x_temp for NNlib.conv!
187- _x_temp = reshape (x_temp, (size (x_temp)... ,1 ,1 ))
193+ # Reshape x_temp for NNlib.conv!
194+ _x_temp = reshape (x_temp, (size (x_temp)... ,1 ,1 ))
188195
189- # Reshape M for NNlib.conv!
190- _M = reshape (M, (size (M)... ,1 ,1 ))
196+ # Reshape M for NNlib.conv!
197+ _M = reshape (M, (size (M)... ,1 ,1 ))
191198
192- _W = reshape (W, (size (W)... ,1 ,1 ))
199+ _W = reshape (W, (size (W)... ,1 ,1 ))
193200
194- # Call NNlib.conv!
195- cv = DenseConvDims (_M, _W, padding= pad, flipkernel= true )
196- conv! (_x_temp, _M, _W, cv)
201+ # Call NNlib.conv!
202+ cv = DenseConvDims (_M, _W, padding= pad, flipkernel= true )
203+ conv! (_x_temp, _M, _W, cv)
197204
198205
199- # convolve boundary and interior points near boundary
200- # partition operator indices along axis of differentiation
201- if pad[1 ] > 0 || pad[2 ] > 0
202- ops_1 = Int64[]
203- ops_1_max_bpc_idx = [0 ]
204- ops_2 = Int64[]
205- ops_2_max_bpc_idx = [0 ]
206- for i in 1 : length (A. ops)
207- L = A. ops[i]
208- if typeof (L). parameters[2 ] == 1
209- push! (ops_1,i)
210- if L. boundary_point_count == pad[1 ]
211- ops_1_max_bpc_idx[1 ] = i
212- end
213- else
214- push! (ops_2,i)
215- if L. boundary_point_count == pad[2 ]
216- ops_2_max_bpc_idx[1 ]= i
206+ # convolve boundary and interior points near boundary
207+ # partition operator indices along axis of differentiation
208+ if pad[1 ] > 0 || pad[2 ] > 0
209+ ops_1 = Int64[]
210+ ops_1_max_bpc_idx = [0 ]
211+ ops_2 = Int64[]
212+ ops_2_max_bpc_idx = [0 ]
213+ for i in 1 : length (A. ops)
214+ L = A. ops[i]
215+ if typeof (L). parameters[2 ] == 1
216+ push! (ops_1,i)
217+ if L. boundary_point_count == pad[1 ]
218+ ops_1_max_bpc_idx[1 ] = i
219+ end
220+ else
221+ push! (ops_2,i)
222+ if L. boundary_point_count == pad[2 ]
223+ ops_2_max_bpc_idx[1 ]= i
224+ end
217225 end
218226 end
219- end
220227
221- # need offsets since some axis may have ghost nodes and some may not
222- offset_x = 0
223- offset_y = 0
228+ # need offsets since some axis may have ghost nodes and some may not
229+ offset_x = 0
230+ offset_y = 0
224231
225- if length (ops_2) > 0
226- offset_x = 1
227- end
228- if length (ops_1) > 0
229- offset_y = 1
230- end
232+ if length (ops_2) > 0
233+ offset_x = 1
234+ end
235+ if length (ops_1) > 0
236+ offset_y = 1
237+ end
231238
232- # convolve boundaries and unaccounted for interior in axis 1
233- if length (ops_1) > 0
234- for i in 1 : size (x_temp)[2 ]
235- convolve_BC_left! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[ops_1_max_bpc_idx... ])
236- convolve_BC_right! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[ops_1_max_bpc_idx... ])
237- if i <= pad[2 ] || i > size (x_temp)[2 ]- pad[2 ]
238- convolve_interior! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[ops_1_max_bpc_idx... ])
239- end
240- # scale by dx
241-
242- for Lidx in ops_1
243- if Lidx != ops_1_max_bpc_idx[ 1 ]
244- convolve_BC_left_add ! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[Lidx])
245- convolve_BC_right_add! ( view (x_temp,:,i), view (M,:,i + offset_x), A . ops[Lidx])
246- if i <= pad[ 2 ] || i > size ( x_temp)[ 2 ] - pad[ 2 ]
247- convolve_interior_add! ( view (x_temp,:,i), view (M,:,i + offset_x), A. ops[Lidx])
248- elseif pad[1 ] - A. ops[Lidx]. boundary_point_count > 0
249- convolve_interior_add_range! ( view (x_temp,:,i), view (M,:,i + offset_x), A . ops[Lidx], pad[ 1 ] - A . ops[Lidx] . boundary_point_count)
239+ # convolve boundaries and unaccounted for interior in axis 1
240+ if length (ops_1) > 0
241+ for i in 1 : size (x_temp)[2 ]
242+ convolve_BC_left! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[ops_1_max_bpc_idx... ])
243+ convolve_BC_right! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[ops_1_max_bpc_idx... ])
244+ if i <= pad[2 ] || i > size (x_temp)[2 ]- pad[2 ]
245+ convolve_interior! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[ops_1_max_bpc_idx... ])
246+ end
247+
248+ for Lidx in ops_1
249+ if Lidx != ops_1_max_bpc_idx[ 1 ]
250+ convolve_BC_left_add! ( view (x_temp,:,i), view (M,:,i + offset_x), A . ops[Lidx])
251+ convolve_BC_right_add ! (view (x_temp,:,i), view (M,:,i+ offset_x), A. ops[Lidx])
252+ if i <= pad[ 2 ] || i > size (x_temp)[ 2 ] - pad[ 2 ]
253+ convolve_interior_add! ( view ( x_temp,:,i), view (M,:,i + offset_x), A . ops[Lidx])
254+ elseif pad[ 1 ] - A. ops[Lidx]. boundary_point_count > 0
255+ convolve_interior_add_range! ( view (x_temp,:,i), view (M,:,i + offset_x), A . ops[Lidx], pad[1 ] - A. ops[Lidx]. boundary_point_count)
256+ end
250257 end
251258 end
252259 end
253260 end
254- end
255- # convolve boundaries and unaccounted for interior in axis 2
256- if length (ops_2) > 0
257- for i in 1 : size (x_temp)[1 ]
258- # in the case of no axis 1 operators, we need to over x_temp
259- if length (ops_1) == 0
260- convolve_BC_left! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
261- convolve_BC_right! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
262- if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
263- convolve_interior! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
264- end
265- # scale by dx
266- # fix here as well
267- else
268- convolve_BC_left_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
269- convolve_BC_right_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
270- if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
271- convolve_interior_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
272- end
273- # scale by dx
274- # fix here as well
275- end
276- for Lidx in ops_2
277- if Lidx != ops_2_max_bpc_idx[1 ]
278- convolve_BC_left_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx])
279- convolve_BC_right_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx])
261+ # convolve boundaries and unaccounted for interior in axis 2
262+ if length (ops_2) > 0
263+ for i in 1 : size (x_temp)[1 ]
264+ # in the case of no axis 1 operators, we need to over x_temp
265+ if length (ops_1) == 0
266+ convolve_BC_left! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
267+ convolve_BC_right! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
268+ if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
269+ convolve_interior! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
270+ end
271+ # scale by dx
272+ # fix here as well
273+ else
274+ convolve_BC_left_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
275+ convolve_BC_right_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
280276 if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
281- convolve_interior_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx])
282- elseif pad[2 ] - A. ops[Lidx]. boundary_point_count > 0
283- convolve_interior_add_range! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx], pad[2 ] - A. ops[Lidx]. boundary_point_count)
277+ convolve_interior_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[ops_2_max_bpc_idx... ])
278+ end
279+ # scale by dx
280+ # fix here as well
281+ end
282+ for Lidx in ops_2
283+ if Lidx != ops_2_max_bpc_idx[1 ]
284+ convolve_BC_left_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx])
285+ convolve_BC_right_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx])
286+ if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
287+ convolve_interior_add! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx])
288+ elseif pad[2 ] - A. ops[Lidx]. boundary_point_count > 0
289+ convolve_interior_add_range! (view (x_temp,i,:), view (M,i+ offset_y,:), A. ops[Lidx], pad[2 ] - A. ops[Lidx]. boundary_point_count)
290+ end
284291 end
285292 end
286293 end
287294 end
288295 end
296+ # Call everything A.ops using fallback
297+ else
298+ N = diff_axis (A. ops[1 ])
299+ if N == 1
300+ mul! (view (x_temp,A. ops[1 ],M)
301+ else
302+ mul! (view (x_temp,A. ops[1 ],M)
303+ end
304+ for L in A. ops[2 : end ]
305+ mul_add! (x_temp,L,M)
306+ end
289307 end
290308end
291309
@@ -299,7 +317,7 @@ function convolve_interior_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
299317 mid = div (A. stencil_length,2 )
300318 for i in (1 + A. boundary_point_count) : (length (x_temp)- A. boundary_point_count)
301319 xtempi = zero (T)
302- cur_stencil = eltype (stencil) <: AbstractVector ? stencil[i] : stencil
320+ cur_stencil = eltype (stencil) <: AbstractVector ? stencil[i- A . boundary_point_count ] : stencil
303321 cur_coeff = typeof (coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
304322 cur_stencil = use_winding (A) && cur_coeff < 0 ? reverse (cur_stencil) : cur_stencil
305323 for idx in 1 : A. stencil_length
0 commit comments