Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit 0087375

Browse files
ajozefiakChrisRackauckas
authored andcommitted
first pass for 3d composition mul, some easy tests are passing
1 parent 29752e5 commit 0087375

File tree

2 files changed

+315
-0
lines changed

2 files changed

+315
-0
lines changed

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,232 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
347347
end
348348
end
349349
end
350+
351+
# A more efficient mul! implementation for compositions of operators which may include regular-grid, centered difference,
352+
# scalar coefficient, non-winding, DerivativeOperator, operating on a 2D or 3D AbstractArray
353+
function LinearAlgebra.mul!(x_temp::AbstractArray{T,3}, A::AbstractDiffEqCompositeOperator, M::AbstractArray{T,3}) where {T}
354+
355+
# opsA operators satisfy conditions for NNlib.conv! call, opsB operators do not
356+
opsA = DerivativeOperator[]
357+
opsB = DerivativeOperator[]
358+
for L in A.ops
359+
if (L.coefficients isa Number || L.coefficients === nothing) && use_winding(L) == false && L.dx isa Number
360+
push!(opsA, L)
361+
else
362+
push!(opsB,L)
363+
end
364+
end
365+
366+
# Check that we can make at least one NNlib.conv! call
367+
if !isempty(opsA)
368+
ndimsM = ndims(M)
369+
Wdims = ones(Int64,ndimsM)
370+
pad = zeros(Int64, ndimsM)
371+
372+
# compute dimensions of interior kernel W
373+
# Here we still use A.ops since operators in opsB may indicate that
374+
# we have more padding to account for
375+
for L in A.ops
376+
axis = typeof(L).parameters[2]
377+
@assert axis <= ndimsM
378+
Wdims[axis] = max(Wdims[axis],L.stencil_length)
379+
pad[axis] = max(pad[axis], L.boundary_point_count)
380+
end
381+
382+
# create zero-valued kernel
383+
W = zeros(T, Wdims...)
384+
mid_Wdims = div.(Wdims,2).+1
385+
idx = div.(Wdims,2).+1
386+
387+
# add to kernel each stencil
388+
for L in opsA
389+
s = L.stencil_coefs
390+
sl = L.stencil_length
391+
axis = typeof(L).parameters[2]
392+
offset = convert(Int64,(Wdims[axis] - sl)/2)
393+
coeff = L.coefficients isa Number ? L.coefficients : true
394+
for i in offset+1:Wdims[axis]-offset
395+
idx[axis]=i
396+
W[idx...] += coeff*s[i-offset]
397+
idx[axis] = mid_Wdims[axis]
398+
end
399+
end
400+
401+
# Reshape x_temp for NNlib.conv!
402+
_x_temp = reshape(x_temp, (size(x_temp)...,1,1))
403+
404+
# Reshape M for NNlib.conv!
405+
_M = reshape(M, (size(M)...,1,1))
406+
407+
_W = reshape(W, (size(W)...,1,1))
408+
409+
# Call NNlib.conv!
410+
cv = DenseConvDims(_M, _W, padding=pad, flipkernel=true)
411+
conv!(_x_temp, _M, _W, cv)
412+
413+
414+
# convolve boundary and interior points near boundary
415+
# partition operator indices along axis of differentiation
416+
if pad[1] > 0 || pad[2] > 0
417+
ops_1 = Int64[]
418+
ops_1_max_bpc_idx = [0]
419+
ops_2 = Int64[]
420+
ops_2_max_bpc_idx = [0]
421+
for i in 1:length(opsA)
422+
L = opsA[i]
423+
if typeof(L).parameters[2] == 1
424+
push!(ops_1,i)
425+
if L.boundary_point_count == pad[1]
426+
ops_1_max_bpc_idx[1] = i
427+
end
428+
else
429+
push!(ops_2,i)
430+
if L.boundary_point_count == pad[2]
431+
ops_2_max_bpc_idx[1]= i
432+
end
433+
end
434+
end
435+
436+
# need offsets since some axis may have ghost nodes and some may not
437+
offset_x = 0
438+
offset_y = 0
439+
440+
if length(ops_2) > 0
441+
offset_x = 1
442+
end
443+
if length(ops_1) > 0
444+
offset_y = 1
445+
end
446+
447+
# convolve boundaries and unaccounted for interior in axis 1
448+
if length(ops_1) > 0
449+
for i in 1:size(x_temp)[2]
450+
convolve_BC_left!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[ops_1_max_bpc_idx...])
451+
convolve_BC_right!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[ops_1_max_bpc_idx...])
452+
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
453+
convolve_interior!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[ops_1_max_bpc_idx...])
454+
end
455+
456+
for Lidx in ops_1
457+
if Lidx != ops_1_max_bpc_idx[1]
458+
convolve_BC_left!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], overwrite = false)
459+
convolve_BC_right!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], overwrite = false)
460+
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
461+
convolve_interior!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], overwrite = false)
462+
elseif pad[1] - opsA[Lidx].boundary_point_count > 0
463+
convolve_interior_add_range!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], pad[1] - opsA[Lidx].boundary_point_count)
464+
end
465+
end
466+
end
467+
end
468+
end
469+
# convolve boundaries and unaccounted for interior in axis 2
470+
if length(ops_2) > 0
471+
for i in 1:size(x_temp)[1]
472+
# in the case of no axis 1 operators, we need to overwrite x_temp
473+
if length(ops_1) == 0
474+
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
475+
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
476+
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
477+
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
478+
end
479+
480+
else
481+
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...], overwrite = false)
482+
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...], overwrite = false)
483+
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
484+
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...], overwrite = false)
485+
end
486+
487+
end
488+
for Lidx in ops_2
489+
if Lidx != ops_2_max_bpc_idx[1]
490+
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], overwrite = false)
491+
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], overwrite = false)
492+
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
493+
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], overwrite = false)
494+
elseif pad[2] - opsA[Lidx].boundary_point_count > 0
495+
convolve_interior_add_range!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], pad[2] - opsA[Lidx].boundary_point_count)
496+
end
497+
end
498+
end
499+
end
500+
end
501+
end
502+
503+
# Here we compute mul! (additively) for every operator in opsB
504+
505+
operating_dims = zeros(Int64,2)
506+
# need to consider all dimensions and operators to determine the truncation
507+
# of M to x_temp
508+
for L in A.ops
509+
if diff_axis(L) == 1
510+
operating_dims[1] = 1
511+
else
512+
operating_dims[2] = 1
513+
end
514+
end
515+
516+
x_temp_1, x_temp_2 = size(x_temp)
517+
518+
for L in opsB
519+
N = diff_axis(L)
520+
if N == 1
521+
if operating_dims[2] == 1
522+
mul!(x_temp,L,view(M,1:x_temp_1+2,1:x_temp_2), overwrite = false)
523+
else
524+
mul!(x_temp,L,M, overwrite = false)
525+
end
526+
else
527+
if operating_dims[1] == 1
528+
mul!(x_temp,L,view(M,1:x_temp_1,1:x_temp_2+2), overwrite = false)
529+
else
530+
mul!(x_temp,L,M, overwrite = false)
531+
end
532+
end
533+
end
534+
535+
# The case where we call everything in A.ops using the fallback mul!
536+
else
537+
# operating_dims indicates which dimensions we are multiplying along
538+
operating_dims = zeros(Int64,3)
539+
for L in A.ops
540+
operating_dims[diff_axis(L)] = 1
541+
end
542+
543+
x_temp_1, x_temp_2, x_temp_3 = size(x_temp)
544+
545+
# Handle first case non-additively
546+
N = diff_axis(A.ops[1])
547+
if N == 1
548+
if operating_dims[2] == 1
549+
mul!(x_temp,A.ops[1],view(M,1:x_temp_1+2,1:x_temp_2))
550+
else
551+
mul!(x_temp,A.ops[1],M)
552+
end
553+
else
554+
if operating_dims[1] == 1
555+
mul!(x_temp,A.ops[1],view(M,1:x_temp_1,1:x_temp_2+2))
556+
else
557+
mul!(x_temp,A.ops[1],M)
558+
end
559+
end
560+
561+
for L in A.ops[2:end]
562+
N = diff_axis(L)
563+
if N == 1
564+
if operating_dims[2] == 1
565+
mul!(x_temp,L,view(M,1:x_temp_1+2,1:x_temp_2), overwrite = false)
566+
else
567+
mul!(x_temp,L,M, overwrite = false)
568+
end
569+
else
570+
if operating_dims[1] == 1
571+
mul!(x_temp,L,view(M,1:x_temp_1,1:x_temp_2+2), overwrite = false)
572+
else
573+
mul!(x_temp,L,M, overwrite = false)
574+
end
575+
end
576+
end
577+
end
578+
end

test/2D_3D_fast_multiplication.jl

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,3 +626,89 @@ end
626626
@test M_temp ((Lx2*M)[1:N,2:N+1]+(Lx3*M)[1:N,2:N+1]+(Lx4*M)[1:N,2:N+1]+(Ly2*M)[2:N+1,1:N]+(Ly3*M)[2:N+1,1:N]+(Ly4*M)[2:N+1,1:N]+(_Ly2*M)[2:N+1,1:N]+(_Ly3*M)[2:N+1,1:N]+(_Ly4*M)[2:N+1,1:N])
627627

628628
end
629+
630+
################################################################################
631+
# 3D Multiplication Tests
632+
################################################################################
633+
634+
@testset "3D Multiplication with no boundary points and dx = dy = dz = 1.0" begin
635+
636+
# Test (Lxx + Lyy + Lzz)*M, dx = dy = dz = 1.0, no coefficient
637+
N = 100
638+
M = zeros(N+2,N+2,N+2)
639+
M_temp = zeros(N,N,N)
640+
641+
for i in 1:N+2
642+
for j in 1:N+2
643+
for k in 1:N+2
644+
M[i,j,k] = cos(0.1i)+sin(0.1j) + exp(0.01k)
645+
end
646+
end
647+
end
648+
649+
Lxx = CenteredDifference{1}(2,2,1.0,N)
650+
Lyy = CenteredDifference{2}(2,2,1.0,N)
651+
Lzz = CenteredDifference{3}(2,2,1.0,N)
652+
A = Lxx + Lyy + Lzz
653+
654+
mul!(M_temp, A, M)
655+
656+
@test M_temp ((Lxx*M)[1:N,2:N+1,2:N+1] + (Lyy*M)[2:N+1,1:N,2:N+1] + (Lzz*M)[2:N+1,2:N+1,1:N])
657+
658+
# Test a single axis, multiple operators: (Lx + Lxx)*M, dx = 1.0
659+
Lx = CenteredDifference{1}(1,2,1.0,N)
660+
A = Lx + Lxx
661+
662+
M_temp = zeros(N,N+2,N+2)
663+
mul!(M_temp, A, M)
664+
665+
@test M_temp ((Lx*M)+(Lxx*M))
666+
667+
# Test a single axis, multiple operators: (Ly + Lyy)*M, dy = 1.0, no coefficient
668+
Ly = CenteredDifference{2}(1,2,1.0,N)
669+
A = Ly + Lyy
670+
671+
M_temp = zeros(N+2,N,N+2)
672+
mul!(M_temp, A, M)
673+
674+
@test M_temp ((Ly*M)+(Lyy*M))
675+
676+
# Test a single axis, multiple operators: (Lz + Lzz)*M, dz = 1.0, no coefficient
677+
Lz = CenteredDifference{3}(1,2,1.0,N)
678+
A = Lz + Lzz
679+
680+
M_temp = zeros(N+2,N+2,N)
681+
mul!(M_temp, A, M)
682+
683+
@test M_temp ((Lz*M)+(Lzz*M))
684+
685+
# Test multiple operators on both axis: (Lx + Ly + Lxx + Lyy)*M, no coefficient
686+
A = Lx + Ly + Lxx + Lyy
687+
M_temp = zeros(N,N,N+2)
688+
mul!(M_temp, A, M)
689+
690+
@test M_temp ((Lx*M)[1:N,2:N+1,:] +(Ly*M)[2:N+1,1:N,:] + (Lxx*M)[1:N,2:N+1,:] +(Lyy*M)[2:N+1,1:N,:])
691+
692+
# Test multiple operators on both axis: (Lx + Lxx + Lz + Lzz)*M, no coefficient
693+
A = Lx + Lxx + Lz + Lzz
694+
M_temp = zeros(N,N+2,N)
695+
mul!(M_temp, A, M)
696+
697+
@test M_temp ((Lx*M)[1:N,:,2:N+1] + (Lxx*M)[1:N,:,2:N+1] + (Lz*M)[2:N+1,:,1:N] +(Lzz*M)[2:N+1,:,1:N])
698+
699+
700+
# Test multiple operators on both axis: (Ly + Lyy + Lz + Lzz)*M, no coefficient
701+
A = Ly + Lyy + Lz + Lzz
702+
M_temp = zeros(N+2,N,N)
703+
mul!(M_temp, A, M)
704+
705+
@test M_temp ((Ly*M)[:,1:N,2:N+1] + (Lyy*M)[:,1:N,2:N+1] + (Lz*M)[:,2:N+1,1:N] +(Lzz*M)[:,2:N+1,1:N])
706+
707+
# Test multiple operators on both axis: (Lx + Ly + Lxx + Lyy + Lz + Lzz)*M, no coefficient
708+
A = Lx + Ly + Lxx + Lyy + Lz + Lzz
709+
M_temp = zeros(N,N,N)
710+
mul!(M_temp, A, M)
711+
712+
@test M_temp ((Lx*M)[1:N,2:N+1,2:N+1] +(Ly*M)[2:N+1,1:N,2:N+1] + (Lxx*M)[1:N,2:N+1,2:N+1] +(Lyy*M)[2:N+1,1:N,2:N+1] + (Lz*M)[2:N+1,2:N+1,1:N] +(Lzz*M)[2:N+1,2:N+1,1:N])
713+
714+
end

0 commit comments

Comments
 (0)