@@ -347,3 +347,261 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
347347 end
348348 end
349349end
350+
# Return `true` when `L` can be folded into the dense kernel handed to `conv!`:
# its coefficients are a `Number` or `nothing`, it does not use winding, and its
# `dx` is a `Number`.
function _mul3d_conv_compatible(L)
    scalar_coeff = L.coefficients isa Number || L.coefficients === nothing
    return scalar_coeff && use_winding(L) == false && L.dx isa Number
end

# Return the window of `M` handed to the single-operator `mul!` for an operator acting
# along `axis`: `dims[axis] + 2` entries along `axis`, `dims[d]` entries along the others.
function _mul3d_axis_window(M, dims, axis)
    n1, n2, n3 = dims
    if axis == 1
        return view(M, 1:(n1 + 2), 1:n2, 1:n3)
    elseif axis == 2
        return view(M, 1:n1, 1:(n2 + 2), 1:n3)
    else
        return view(M, 1:n1, 1:n2, 1:(n3 + 2))
    end
end

# Return the 1-D line of `X` running along the axis named by the `Val`, located at
# `(i, j)` in the two remaining axes (taken in ascending order).
_mul3d_line(X, ::Val{1}, i, j) = view(X, :, i, j)
_mul3d_line(X, ::Val{2}, i, j) = view(X, i, :, j)
_mul3d_line(X, ::Val{3}, i, j) = view(X, i, j, :)

# Apply the boundary stencils of every kernel-compatible operator acting along axis `AX`
# and redo/patch the interior on lines where the `conv!` result cannot be used as-is.
#
# `ops`            all kernel-compatible operators
# `members`        indices into `ops` of the operators acting along `AX`
# `lead_idx`       index into `ops` of the operator whose boundary_point_count == pad[AX]
# `pad`            per-axis padding that was passed to `conv!`
# `shift`          per-axis index shift applied to `M` for the two cross axes
# `overwrite_lead` whether the lead operator overwrites `x_temp` (first axis processed)
#                  or accumulates into it
function _mul3d_fix_axis!(x_temp, M, ops, members, lead_idx, ax::Val{AX}, pad, shift,
                          overwrite_lead) where {AX}
    c1, c2 = AX == 1 ? (2, 3) : (AX == 2 ? (1, 3) : (1, 2))
    n1, n2 = size(x_temp, c1), size(x_temp, c2)

    for i in 1:n1, j in 1:n2
        out = _mul3d_line(x_temp, ax, i, j)
        src = _mul3d_line(M, ax, i + shift[c1], j + shift[c2])

        # Lines within `pad` of a cross-axis edge get their interior recomputed from
        # scratch instead of reusing the conv! output.
        near_edge = i <= pad[c1] || i > n1 - pad[c1] || j <= pad[c2] || j > n2 - pad[c2]

        # NOTE(review): `lead_idx` is 0 when no operator in `ops` on this axis has
        # boundary_point_count == pad[AX] (pad set only by a non-kernel operator);
        # indexing then throws a BoundsError, exactly as the previous inline code did.
        lead = ops[lead_idx]
        if overwrite_lead
            convolve_BC_left!(out, src, lead)
            convolve_BC_right!(out, src, lead)
            if near_edge
                convolve_interior!(out, src, lead)
            end
        else
            convolve_BC_left!(out, src, lead, overwrite = false)
            convolve_BC_right!(out, src, lead, overwrite = false)
            if near_edge
                convolve_interior!(out, src, lead, overwrite = false)
            end
        end

        # Remaining operators on this axis always accumulate.
        for k in members
            k == lead_idx && continue
            op = ops[k]
            convolve_BC_left!(out, src, op, overwrite = false)
            convolve_BC_right!(out, src, op, overwrite = false)
            if near_edge
                convolve_interior!(out, src, op, overwrite = false)
            else
                # Operators with fewer boundary points than the padding get an extra
                # additive interior pass over the difference.
                extra = pad[AX] - op.boundary_point_count
                if extra > 0
                    convolve_interior_add_range!(out, src, op, extra)
                end
            end
        end
    end
    return nothing
end

"""
    mul!(x_temp::AbstractArray{T,3}, A::AbstractDiffEqCompositeOperator,
         M::AbstractArray{T,3})

Apply the composition `A` to the 3-D array `M`, writing the result into `x_temp`.

The operators in `A.ops` are split into two groups:

- operators whose `coefficients` are a `Number` or `nothing`, that do not use winding and
  whose `dx` is a `Number` are summed into one dense kernel and applied with a single
  `conv!` call; their boundary stencils are then applied line by line;
- all other operators are applied additively through the single-operator `mul!`.

If no operator qualifies for the kernel, every operator goes through the single-operator
`mul!`: the first one overwrites `x_temp`, the rest accumulate.

The windows taken from `M` are two entries longer than `x_temp` along the axis an
operator acts on (presumably the ghost nodes mentioned in the original comments; the
exact layout contract of `M` is not visible here and should be confirmed by the caller).

# Throws
- `ArgumentError` if an operator acts along an axis outside `1:3`.

# Returns
`nothing`; `x_temp` is modified in place.
"""
function LinearAlgebra.mul!(x_temp::AbstractArray{T,3}, A::AbstractDiffEqCompositeOperator,
                            M::AbstractArray{T,3}) where {T}
    # Validate once, up front, instead of relying on @assert / incidental BoundsErrors.
    for L in A.ops
        axis = diff_axis(L)
        1 <= axis <= 3 || throw(ArgumentError(
            "operator differentiates along axis $axis, but the arrays are 3-dimensional"))
    end

    # opsA operators satisfy the conditions for the conv! call, opsB operators do not.
    opsA = DerivativeOperator[]
    opsB = DerivativeOperator[]
    for L in A.ops
        push!(_mul3d_conv_compatible(L) ? opsA : opsB, L)
    end

    dims = size(x_temp)

    # No operator can go through conv!: use the single-operator mul! for everything.
    if isempty(opsA)
        # First operator overwrites x_temp, the remaining ones accumulate.
        head = A.ops[1]
        mul!(x_temp, head, _mul3d_axis_window(M, dims, diff_axis(head)))
        for L in Iterators.drop(A.ops, 1)
            mul!(x_temp, L, _mul3d_axis_window(M, dims, diff_axis(L)), overwrite = false)
        end
        return nothing
    end

    # Kernel extent and padding per axis. All of A.ops is scanned (not only opsA)
    # because operators in opsB may require more padding.
    Wdims = ones(Int64, 3)
    pad = zeros(Int64, 3)
    for L in A.ops
        axis = diff_axis(L)
        Wdims[axis] = max(Wdims[axis], L.stencil_length)
        pad[axis] = max(pad[axis], L.boundary_point_count)
    end

    # Zero kernel; each opsA stencil is added along its axis through the kernel centre.
    W = zeros(T, Wdims...)
    mid_Wdims = div.(Wdims, 2) .+ 1
    idx = copy(mid_Wdims)
    for L in opsA
        s = L.stencil_coefs
        axis = diff_axis(L)
        # Centre the stencil inside the (possibly longer) kernel extent; convert throws
        # if the length difference is odd.
        offset = convert(Int64, (Wdims[axis] - L.stencil_length) / 2)
        # `true` is the multiplicative identity when there is no scalar coefficient.
        coeff = L.coefficients isa Number ? L.coefficients : true
        for i in (offset + 1):(Wdims[axis] - offset)
            idx[axis] = i
            W[idx...] += coeff * s[i - offset]
        end
        idx[axis] = mid_Wdims[axis]
    end

    # conv! is called on arrays with two trailing singleton dimensions.
    _x_temp = reshape(x_temp, (size(x_temp)..., 1, 1))
    _M = reshape(M, (size(M)..., 1, 1))
    _W = reshape(W, (size(W)..., 1, 1))

    cv = DenseConvDims(_M, _W, padding = pad, flipkernel = true)
    conv!(_x_temp, _M, _W, cv)

    # Boundary points, and interior points near a boundary, are handled line by line.
    if any(p -> p > 0, pad)
        # Partition opsA indices by axis and remember, per axis, the (last) operator
        # whose boundary_point_count equals the padding on that axis.
        members = (Int64[], Int64[], Int64[])
        lead = zeros(Int64, 3)
        for (k, L) in enumerate(opsA)
            axis = diff_axis(L)
            push!(members[axis], k)
            if L.boundary_point_count == pad[axis]
                lead[axis] = k
            end
        end

        # Cross-axis index shift into M: 1 on axes that have a kernel-compatible
        # operator, 0 otherwise.
        # NOTE(review): the shift ignores operators in opsB; confirm that is intended
        # when only a non-kernel operator acts along an axis.
        shift = map(m -> isempty(m) ? 0 : 1, members)

        # The first axis that has operators overwrites x_temp at the points it touches;
        # later axes accumulate.
        if !isempty(members[1])
            _mul3d_fix_axis!(x_temp, M, opsA, members[1], lead[1], Val(1), pad, shift,
                             true)
        end
        if !isempty(members[2])
            _mul3d_fix_axis!(x_temp, M, opsA, members[2], lead[2], Val(2), pad, shift,
                             isempty(members[1]))
        end
        if !isempty(members[3])
            _mul3d_fix_axis!(x_temp, M, opsA, members[3], lead[3], Val(3), pad, shift,
                             isempty(members[1]) && isempty(members[2]))
        end
    end

    # Operators that could not go into the kernel are applied additively.
    for L in opsB
        mul!(x_temp, L, _mul3d_axis_window(M, dims, diff_axis(L)), overwrite = false)
    end

    return nothing
end
0 commit comments