@@ -347,3 +347,232 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
347347 end
348348 end
349349end
350+
351+ # A more efficient mul! implementation for compositions of operators which may include regular-grid, centered difference,
352+ # scalar coefficient, non-winding, DerivativeOperator, operating on a 2D or 3D AbstractArray
#
# Overview of this method (rank-3 `x_temp` and `M`):
#   1. Split `A.ops` into `opsA` (candidates for a single NNlib `conv!` call) and `opsB` (the rest).
#   2. If `opsA` is non-empty: sum the `opsA` stencils into one kernel `W`, run `conv!` into `x_temp`,
#      patch up boundary / near-boundary entries with the `convolve_*!` helpers, then add the
#      contribution of every `opsB` operator through `mul!(...; overwrite = false)`.
#   3. Otherwise: apply every operator with the per-operator `mul!`, the first one overwriting
#      `x_temp` and the remaining ones called with `overwrite = false`.
# `x_temp` is mutated in place; there is no explicit return statement.
#
# NOTE(review): the signature dispatches on rank-3 arrays, but below the boundary handling and
# the per-operator `mul!` calls only consider axes 1 and 2 and index `x_temp` / `M` with two
# indices. Julia only accepts omitted trailing indices when those dimensions have length 1, so
# this looks like it was carried over from a 2-D method and is not yet general 3-D -- confirm.
353+ function LinearAlgebra. mul! (x_temp:: AbstractArray{T,3} , A:: AbstractDiffEqCompositeOperator , M:: AbstractArray{T,3} ) where {T}
354+
355+ # opsA operators satisfy conditions for NNlib.conv! call, opsB operators do not
356+ opsA = DerivativeOperator[]
357+ opsB = DerivativeOperator[]
358+ for L in A. ops
# opsA criteria: coefficients are a scalar or `nothing`, `use_winding(L)` is false, and `dx` is a scalar.
359+ if (L. coefficients isa Number || L. coefficients === nothing ) && use_winding (L) == false && L. dx isa Number
360+ push! (opsA, L)
361+ else
362+ push! (opsB,L)
363+ end
364+ end
365+
366+ # Check that we can make at least one NNlib.conv! call
367+ if ! isempty (opsA)
368+ ndimsM = ndims (M)
# Wdims: per-axis kernel extent (starts at 1); pad: per-axis padding handed to the convolution.
369+ Wdims = ones (Int64,ndimsM)
370+ pad = zeros (Int64, ndimsM)
371+
372+ # compute dimensions of interior kernel W
373+ # Here we still use A.ops since operators in opsB may indicate that
374+ # we have more padding to account for
375+ for L in A. ops
# The axis is read from the second type parameter of L's type; further down `diff_axis(L)` is
# used instead -- presumably the two are equivalent, TODO confirm.
376+ axis = typeof (L). parameters[2 ]
377+ @assert axis <= ndimsM
378+ Wdims[axis] = max (Wdims[axis],L. stencil_length)
379+ pad[axis] = max (pad[axis], L. boundary_point_count)
380+ end
381+
382+ # create zero-valued kernel
383+ W = zeros (T, Wdims... )
# mid_Wdims is the centre index of W along each axis; idx is a mutable cursor that starts there.
384+ mid_Wdims = div .(Wdims,2 ).+ 1
385+ idx = div .(Wdims,2 ).+ 1
386+
387+ # add to kernel each stencil
388+ for L in opsA
389+ s = L. stencil_coefs
390+ sl = L. stencil_length
391+ axis = typeof (L). parameters[2 ]
# Centre a shorter stencil inside the kernel extent of its axis. `/` produces a Float64, so the
# `convert` throws an InexactError when `Wdims[axis] - sl` is odd.
392+ offset = convert (Int64,(Wdims[axis] - sl)/ 2 )
# `true` is the multiplicative identity, used when the operator carries no scalar coefficient.
393+ coeff = L. coefficients isa Number ? L. coefficients : true
394+ for i in offset+ 1 : Wdims[axis]- offset
# Walk along `axis` through the centre of W, accumulating this operator's stencil; the cursor
# is put back on the centre after every step so the other axes stay centred.
395+ idx[axis]= i
396+ W[idx... ] += coeff* s[i- offset]
397+ idx[axis] = mid_Wdims[axis]
398+ end
399+ end
400+
401+ # Reshape x_temp for NNlib.conv!
# Two trailing singleton dimensions are appended (presumably channel and batch -- confirm against NNlib).
402+ _x_temp = reshape (x_temp, (size (x_temp)... ,1 ,1 ))
403+
404+ # Reshape M for NNlib.conv!
405+ _M = reshape (M, (size (M)... ,1 ,1 ))
406+
407+ _W = reshape (W, (size (W)... ,1 ,1 ))
408+
409+ # Call NNlib.conv!
# NOTE(review): `flipkernel = true` presumably requests cross-correlation so the stencil is applied
# unflipped -- verify against the NNlib version in use.
410+ cv = DenseConvDims (_M, _W, padding= pad, flipkernel= true )
411+ conv! (_x_temp, _M, _W, cv)
412+
413+
414+ # convolve boundary and interior points near boundary
415+ # partition operator indices along axis of differentiation
# NOTE(review): pad[3] is not examined here although M has three dimensions.
416+ if pad[1 ] > 0 || pad[2 ] > 0
# ops_1 / ops_2 hold indices into opsA. ops_N_max_bpc_idx is a 1-element vector holding the index of
# an operator whose boundary_point_count equals pad[N]; it stays 0 when no opsA operator matches
# (pad was computed over all of A.ops, including opsB), in which case the `opsA[...]` lookups
# below would be out of bounds -- NOTE(review): confirm this cannot happen.
417+ ops_1 = Int64[]
418+ ops_1_max_bpc_idx = [0 ]
419+ ops_2 = Int64[]
420+ ops_2_max_bpc_idx = [0 ]
421+ for i in 1 : length (opsA)
422+ L = opsA[i]
423+ if typeof (L). parameters[2 ] == 1
424+ push! (ops_1,i)
425+ if L. boundary_point_count == pad[1 ]
426+ ops_1_max_bpc_idx[1 ] = i
427+ end
428+ else
# NOTE(review): every operator whose axis is not 1 (including axis 3) lands in ops_2 and is
# compared against pad[2].
429+ push! (ops_2,i)
430+ if L. boundary_point_count == pad[2 ]
431+ ops_2_max_bpc_idx[1 ]= i
432+ end
433+ end
434+ end
435+
436+ # need offsets since some axis may have ghost nodes and some may not
# offset_x shifts the second index of M (used in the axis-1 pass) and is enabled when axis-2
# operators exist; offset_y shifts the first index of M (used in the axis-2 pass) and is enabled
# when axis-1 operators exist. Presumably M carries one ghost node per side along each
# differentiated axis -- TODO confirm.
437+ offset_x = 0
438+ offset_y = 0
439+
440+ if length (ops_2) > 0
441+ offset_x = 1
442+ end
443+ if length (ops_1) > 0
444+ offset_y = 1
445+ end
446+
447+ # convolve boundaries and unaccounted for interior in axis 1
# NOTE(review): the views below index rank-3 arrays with two indices (see note at the top).
448+ if length (ops_1) > 0
449+ for i in 1 : size (x_temp)[2 ]
# The operator with the largest boundary point count is applied first without `overwrite = false`
# (presumably overwriting the boundary entries -- confirm helper default).
450+ convolve_BC_left! (view (x_temp,:,i), view (M,:,i+ offset_x), opsA[ops_1_max_bpc_idx... ])
451+ convolve_BC_right! (view (x_temp,:,i), view (M,:,i+ offset_x), opsA[ops_1_max_bpc_idx... ])
# Slices within pad[2] of either end along axis 2 additionally get their interior recomputed.
452+ if i <= pad[2 ] || i > size (x_temp)[2 ]- pad[2 ]
453+ convolve_interior! (view (x_temp,:,i), view (M,:,i+ offset_x), opsA[ops_1_max_bpc_idx... ])
454+ end
455+
# Remaining axis-1 operators are applied with `overwrite = false`.
456+ for Lidx in ops_1
457+ if Lidx != ops_1_max_bpc_idx[1 ]
458+ convolve_BC_left! (view (x_temp,:,i), view (M,:,i+ offset_x), opsA[Lidx], overwrite = false )
459+ convolve_BC_right! (view (x_temp,:,i), view (M,:,i+ offset_x), opsA[Lidx], overwrite = false )
460+ if i <= pad[2 ] || i > size (x_temp)[2 ]- pad[2 ]
461+ convolve_interior! (view (x_temp,:,i), view (M,:,i+ offset_x), opsA[Lidx], overwrite = false )
462+ elseif pad[1 ] - opsA[Lidx]. boundary_point_count > 0
# This operator has fewer boundary points than pad[1]; the difference is passed on to the
# helper (presumably the width of the strip that still needs its interior stencil -- confirm).
463+ convolve_interior_add_range! (view (x_temp,:,i), view (M,:,i+ offset_x), opsA[Lidx], pad[1 ] - opsA[Lidx]. boundary_point_count)
464+ end
465+ end
466+ end
467+ end
468+ end
469+ # convolve boundaries and unaccounted for interior in axis 2
470+ if length (ops_2) > 0
471+ for i in 1 : size (x_temp)[1 ]
472+ # in the case of no axis 1 operators, we need to overwrite x_temp
473+ if length (ops_1) == 0
474+ convolve_BC_left! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
475+ convolve_BC_right! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
476+ if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
477+ convolve_interior! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ])
478+ end
479+
480+ else
# Axis-1 results are already stored in x_temp, so the axis-2 pass is called with `overwrite = false`.
481+ convolve_BC_left! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ], overwrite = false )
482+ convolve_BC_right! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ], overwrite = false )
483+ if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
484+ convolve_interior! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[ops_2_max_bpc_idx... ], overwrite = false )
485+ end
486+
487+ end
488+ for Lidx in ops_2
489+ if Lidx != ops_2_max_bpc_idx[1 ]
490+ convolve_BC_left! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[Lidx], overwrite = false )
491+ convolve_BC_right! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[Lidx], overwrite = false )
492+ if i <= pad[1 ] || i > size (x_temp)[1 ]- pad[1 ]
493+ convolve_interior! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[Lidx], overwrite = false )
494+ elseif pad[2 ] - opsA[Lidx]. boundary_point_count > 0
495+ convolve_interior_add_range! (view (x_temp,i,:), view (M,i+ offset_y,:), opsA[Lidx], pad[2 ] - opsA[Lidx]. boundary_point_count)
496+ end
497+ end
498+ end
499+ end
500+ end
501+ end
502+
503+ # Here we compute mul! (additively) for every operator in opsB
504+
# operating_dims[k] == 1 marks that some operator in A.ops differentiates along axis k.
# NOTE(review): only two slots exist here; any axis other than 1 (including 3) sets slot 2.
505+ operating_dims = zeros (Int64,2 )
506+ # need to consider all dimensions and operators to determine the truncation
507+ # of M to x_temp
508+ for L in A. ops
509+ if diff_axis (L) == 1
510+ operating_dims[1 ] = 1
511+ else
512+ operating_dims[2 ] = 1
513+ end
514+ end
515+
# Only the first two extents of the rank-3 size tuple are destructured.
516+ x_temp_1, x_temp_2 = size (x_temp)
517+
518+ for L in opsB
519+ N = diff_axis (L)
520+ if N == 1
521+ if operating_dims[2 ] == 1
# When the other axis is also differentiated, M is trimmed so that only the operator's own axis
# keeps the extra 2 entries (presumably one ghost node per side -- confirm).
# NOTE(review): two-index view of a rank-3 array.
522+ mul! (x_temp,L,view (M,1 : x_temp_1+ 2 ,1 : x_temp_2), overwrite = false )
523+ else
524+ mul! (x_temp,L,M, overwrite = false )
525+ end
526+ else
527+ if operating_dims[1 ] == 1
528+ mul! (x_temp,L,view (M,1 : x_temp_1,1 : x_temp_2+ 2 ), overwrite = false )
529+ else
530+ mul! (x_temp,L,M, overwrite = false )
531+ end
532+ end
533+ end
534+
535+ # The case where we call everything in A.ops using the fallback mul!
536+ else
537+ # operating_dims indicates which dimensions we are multiplying along
# NOTE(review): three slots are filled here, but only slots 1 and 2 are read below.
538+ operating_dims = zeros (Int64,3 )
539+ for L in A. ops
540+ operating_dims[diff_axis (L)] = 1
541+ end
542+
# x_temp_3 is assigned but not used in the remainder of this branch.
543+ x_temp_1, x_temp_2, x_temp_3 = size (x_temp)
544+
545+ # Handle first case non-additively
# NOTE(review): `A.ops[1]` assumes A.ops is non-empty; an axis-3 first operator takes the
# `else` (axis-2) path below.
546+ N = diff_axis (A. ops[1 ])
547+ if N == 1
548+ if operating_dims[2 ] == 1
549+ mul! (x_temp,A. ops[1 ],view (M,1 : x_temp_1+ 2 ,1 : x_temp_2))
550+ else
551+ mul! (x_temp,A. ops[1 ],M)
552+ end
553+ else
554+ if operating_dims[1 ] == 1
555+ mul! (x_temp,A. ops[1 ],view (M,1 : x_temp_1,1 : x_temp_2+ 2 ))
556+ else
557+ mul! (x_temp,A. ops[1 ],M)
558+ end
559+ end
560+
# Remaining operators are called with `overwrite = false` on top of the first result.
561+ for L in A. ops[2 : end ]
562+ N = diff_axis (L)
563+ if N == 1
564+ if operating_dims[2 ] == 1
565+ mul! (x_temp,L,view (M,1 : x_temp_1+ 2 ,1 : x_temp_2), overwrite = false )
566+ else
567+ mul! (x_temp,L,M, overwrite = false )
568+ end
569+ else
570+ if operating_dims[1 ] == 1
571+ mul! (x_temp,L,view (M,1 : x_temp_1,1 : x_temp_2+ 2 ), overwrite = false )
572+ else
573+ mul! (x_temp,L,M, overwrite = false )
574+ end
575+ end
576+ end
577+ end
578+ end
0 commit comments