Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit 806acc1

Browse files
Merge pull request #152 from JuliaDiffEq/3d_mul_compositions
[WIP] 3D mul! for compositions
2 parents 9d0e329 + 2da4fb9 commit 806acc1

File tree

2 files changed

+1265
-0
lines changed

2 files changed

+1265
-0
lines changed

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,261 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
347347
end
348348
end
349349
end
350+
351+
# A more efficient mul! implementation for compositions of operators which may include regular-grid, centered difference,
352+
# scalar coefficient, non-winding, DerivativeOperator, operating on a 2D or 3D AbstractArray
353+
# Apply the composite operator `A` to the 3D array `M`, writing the result into `x_temp`.
#
# The operators in `A.ops` are split into two groups:
#   * `opsA`: operators with a scalar (or `nothing`) coefficient, no winding and a scalar
#     `dx`. Their stencils are accumulated into one dense kernel `W`, which is applied
#     with a single `NNlib.conv!` call. Output points within `pad` of an edge are then
#     redone operator by operator with the `convolve_*` helpers.
#   * `opsB`: every other operator. Each is applied additively through its own `mul!`.
# If `opsA` is empty, every operator goes through its own `mul!`; the first call
# overwrites `x_temp` and the remaining ones accumulate into it.
#
# Arguments:
#   x_temp : destination array, overwritten.
#   A      : composite operator; only `A.ops` is read here.
#   M      : input array. The per-operator `mul!` calls take a view that is two entries
#            longer than `x_temp` along the differentiated axis (the ghost nodes).
#
# Returns `x_temp`, following the `mul!` convention.
#
# NOTE(review): the offsets into `M` and the `*_max_bpc_idx` lookups are derived from
# `opsA` only, while `pad` and the kernel size are derived from all of `A.ops`. Confirm the
# behaviour when an axis is differentiated only by an `opsB` operator.
function LinearAlgebra.mul!(x_temp::AbstractArray{T,3}, A::AbstractDiffEqCompositeOperator, M::AbstractArray{T,3}) where {T}

    # opsA operators satisfy conditions for NNlib.conv! call, opsB operators do not
    opsA = DerivativeOperator[]
    opsB = DerivativeOperator[]
    for L in A.ops
        if (L.coefficients isa Number || L.coefficients === nothing) && use_winding(L) == false && L.dx isa Number
            push!(opsA, L)
        else
            push!(opsB, L)
        end
    end

    x_temp_1, x_temp_2, x_temp_3 = size(x_temp)

    # Check that we can make at least one NNlib.conv! call
    if !isempty(opsA)
        ndimsM = ndims(M)
        Wdims = ones(Int64, ndimsM)
        pad = zeros(Int64, ndimsM)

        # compute dimensions of interior kernel W
        # Here we still use A.ops since operators in opsB may indicate that
        # we have more padding to account for
        for L in A.ops
            axis = typeof(L).parameters[2]
            @assert axis <= ndimsM
            Wdims[axis] = max(Wdims[axis], L.stencil_length)
            pad[axis] = max(pad[axis], L.boundary_point_count)
        end

        # create zero-valued kernel
        W = zeros(T, Wdims...)
        mid_Wdims = div.(Wdims, 2) .+ 1
        idx = div.(Wdims, 2) .+ 1

        # add to kernel each stencil; a stencil lies on the line through the kernel
        # centre along its own axis, so only idx[axis] moves and is reset afterwards
        for L in opsA
            s = L.stencil_coefs
            sl = L.stencil_length
            axis = typeof(L).parameters[2]
            offset = convert(Int64, (Wdims[axis] - sl)/2)
            coeff = L.coefficients isa Number ? L.coefficients : true
            for i in offset+1:Wdims[axis]-offset
                idx[axis] = i
                W[idx...] += coeff*s[i-offset]
                idx[axis] = mid_Wdims[axis]
            end
        end

        # Reshape x_temp, M and W for NNlib.conv! (two trailing singleton dimensions)
        _x_temp = reshape(x_temp, (size(x_temp)..., 1, 1))
        _M = reshape(M, (size(M)..., 1, 1))
        _W = reshape(W, (size(W)..., 1, 1))

        # Call NNlib.conv!
        cv = DenseConvDims(_M, _W, padding=pad, flipkernel=true)
        conv!(_x_temp, _M, _W, cv)

        # convolve boundary and interior points near boundary
        # partition operator indices along axis of differentiation
        if pad[1] > 0 || pad[2] > 0 || pad[3] > 0
            ops_1 = Int64[]
            ops_1_max_bpc_idx = [0]
            ops_2 = Int64[]
            ops_2_max_bpc_idx = [0]
            ops_3 = Int64[]
            ops_3_max_bpc_idx = [0]

            for i in 1:length(opsA)
                L = opsA[i]
                if typeof(L).parameters[2] == 1
                    push!(ops_1, i)
                    if L.boundary_point_count == pad[1]
                        ops_1_max_bpc_idx[1] = i
                    end
                elseif typeof(L).parameters[2] == 2
                    push!(ops_2, i)
                    if L.boundary_point_count == pad[2]
                        ops_2_max_bpc_idx[1] = i
                    end
                else
                    push!(ops_3, i)
                    if L.boundary_point_count == pad[3]
                        ops_3_max_bpc_idx[1] = i
                    end
                end
            end

            # need offsets since some axis may have ghost nodes and some may not
            offset_x = 0
            offset_y = 0
            offset_z = 0

            if length(ops_1) > 0
                offset_x = 1
            end
            if length(ops_2) > 0
                offset_y = 1
            end
            if length(ops_3) > 0
                offset_z = 1
            end

            # convolve boundaries and unaccounted for interior in axis 1
            if length(ops_1) > 0
                for i in 1:x_temp_2
                    for j in 1:x_temp_3
                        convolve_BC_left!(view(x_temp,:,i,j), view(M,:,i+offset_y,j+offset_z), opsA[ops_1_max_bpc_idx...])
                        convolve_BC_right!(view(x_temp,:,i,j), view(M,:,i+offset_y,j+offset_z), opsA[ops_1_max_bpc_idx...])
                        if i <= pad[2] || i > x_temp_2-pad[2] || j <= pad[3] || j > x_temp_3-pad[3]
                            convolve_interior!(view(x_temp,:,i,j), view(M,:,i+offset_y,j+offset_z), opsA[ops_1_max_bpc_idx...])
                        end

                        for Lidx in ops_1
                            if Lidx != ops_1_max_bpc_idx[1]
                                convolve_BC_left!(view(x_temp,:,i,j), view(M,:,i+offset_y,j+offset_z), opsA[Lidx], overwrite = false)
                                convolve_BC_right!(view(x_temp,:,i,j), view(M,:,i+offset_y,j+offset_z), opsA[Lidx], overwrite = false)
                                if i <= pad[2] || i > x_temp_2-pad[2] || j <= pad[3] || j > x_temp_3-pad[3]
                                    convolve_interior!(view(x_temp,:,i,j), view(M,:,i+offset_y,j+offset_z), opsA[Lidx], overwrite = false)
                                elseif pad[1] - opsA[Lidx].boundary_point_count > 0
                                    convolve_interior_add_range!(view(x_temp,:,i,j), view(M,:,i+offset_y,j+offset_z), opsA[Lidx], pad[1] - opsA[Lidx].boundary_point_count)
                                end
                            end
                        end
                    end
                end
            end

            # convolve boundaries and unaccounted for interior in axis 2
            if length(ops_2) > 0
                for i in 1:x_temp_1
                    for j in 1:x_temp_3
                        # in the case of no axis 1 operators, we need to overwrite x_temp
                        if length(ops_1) == 0
                            convolve_BC_left!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[ops_2_max_bpc_idx...])
                            convolve_BC_right!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[ops_2_max_bpc_idx...])
                            if i <= pad[1] || i > x_temp_1-pad[1] || j <= pad[3] || j > x_temp_3-pad[3]
                                convolve_interior!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[ops_2_max_bpc_idx...])
                            end

                        else
                            convolve_BC_left!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[ops_2_max_bpc_idx...], overwrite = false)
                            convolve_BC_right!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[ops_2_max_bpc_idx...], overwrite = false)
                            if i <= pad[1] || i > x_temp_1-pad[1] || j <= pad[3] || j > x_temp_3-pad[3]
                                # first index of M runs along axis 1, so it takes the axis-1 offset
                                convolve_interior!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[ops_2_max_bpc_idx...], overwrite = false)
                            end

                        end
                        for Lidx in ops_2
                            if Lidx != ops_2_max_bpc_idx[1]
                                convolve_BC_left!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[Lidx], overwrite = false)
                                convolve_BC_right!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[Lidx], overwrite = false)
                                if i <= pad[1] || i > x_temp_1-pad[1] || j <= pad[3] || j > x_temp_3-pad[3]
                                    convolve_interior!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[Lidx], overwrite = false)
                                elseif pad[2] - opsA[Lidx].boundary_point_count > 0
                                    convolve_interior_add_range!(view(x_temp,i,:,j), view(M,i+offset_x,:,j+offset_z), opsA[Lidx], pad[2] - opsA[Lidx].boundary_point_count)
                                end
                            end
                        end
                    end
                end
            end

            # convolve boundaries and unaccounted for interior in axis 3
            if length(ops_3) > 0
                for i in 1:x_temp_1
                    for j in 1:x_temp_2
                        # in the case of no axis 1 and 2 operators, we need to overwrite x_temp
                        if length(ops_1) == 0 && length(ops_2) == 0
                            convolve_BC_left!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[ops_3_max_bpc_idx...])
                            convolve_BC_right!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[ops_3_max_bpc_idx...])
                            # lines near an edge of either remaining axis are redone, as in
                            # every other branch (the axis-2 test was missing here)
                            if i <= pad[1] || i > x_temp_1-pad[1] || j <= pad[2] || j > x_temp_2-pad[2]
                                convolve_interior!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[ops_3_max_bpc_idx...])
                            end

                        else
                            convolve_BC_left!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[ops_3_max_bpc_idx...], overwrite = false)
                            convolve_BC_right!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[ops_3_max_bpc_idx...], overwrite = false)
                            if i <= pad[1] || i > x_temp_1-pad[1] || j <= pad[2] || j > x_temp_2-pad[2]
                                convolve_interior!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[ops_3_max_bpc_idx...], overwrite = false)
                            end

                        end
                        for Lidx in ops_3
                            if Lidx != ops_3_max_bpc_idx[1]
                                convolve_BC_left!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[Lidx], overwrite = false)
                                convolve_BC_right!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[Lidx], overwrite = false)
                                if i <= pad[1] || i > x_temp_1-pad[1] || j <= pad[2] || j > x_temp_2-pad[2]
                                    convolve_interior!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[Lidx], overwrite = false)
                                elseif pad[3] - opsA[Lidx].boundary_point_count > 0
                                    convolve_interior_add_range!(view(x_temp,i,j,:), view(M,i+offset_x,j+offset_y,:), opsA[Lidx], pad[3] - opsA[Lidx].boundary_point_count)
                                end
                            end
                        end
                    end
                end
            end
        end

        # Here we compute mul! (additively) for every operator in opsB.
        # NOTE(review): the view of M is truncated to the size of x_temp on the
        # non-differentiated axes and extended by 2 on the differentiated one.
        for L in opsB
            N = diff_axis(L)
            if N == 1
                mul!(x_temp, L, view(M,1:x_temp_1+2,1:x_temp_2,1:x_temp_3), overwrite = false)
            elseif N == 2
                mul!(x_temp, L, view(M,1:x_temp_1,1:x_temp_2+2,1:x_temp_3), overwrite = false)
            else
                mul!(x_temp, L, view(M,1:x_temp_1,1:x_temp_2,1:x_temp_3+2), overwrite = false)
            end
        end

    # The case where we call everything in A.ops using the fallback mul!
    else
        # Handle first case non-additively
        N = diff_axis(A.ops[1])

        if N == 1
            mul!(x_temp, A.ops[1], view(M,1:x_temp_1+2,1:x_temp_2,1:x_temp_3))
        elseif N == 2
            mul!(x_temp, A.ops[1], view(M,1:x_temp_1,1:x_temp_2+2,1:x_temp_3))
        else
            mul!(x_temp, A.ops[1], view(M,1:x_temp_1,1:x_temp_2,1:x_temp_3+2))
        end

        # Remaining operators accumulate into x_temp
        for L in A.ops[2:end]
            N = diff_axis(L)
            if N == 1
                mul!(x_temp, L, view(M,1:x_temp_1+2,1:x_temp_2,1:x_temp_3), overwrite = false)
            elseif N == 2
                mul!(x_temp, L, view(M,1:x_temp_1,1:x_temp_2+2,1:x_temp_3), overwrite = false)
            else
                mul!(x_temp, L, view(M,1:x_temp_1,1:x_temp_2,1:x_temp_3+2), overwrite = false)
            end
        end
    end

    return x_temp
end

0 commit comments

Comments
 (0)