Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit de12dc6

Browse files
committed
some handling for different types of operators in fast 2d/3d call
1 parent deca5ba commit de12dc6

File tree

4 files changed

+140
-121
lines changed

4 files changed

+140
-121
lines changed

src/derivative_operators/convolutions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ function convolve_interior!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::
1515
mid = div(A.stencil_length,2)
1616
for i in (1+A.boundary_point_count) : (length(x_temp)-A.boundary_point_count)
1717
xtempi = zero(T)
18-
cur_stencil = eltype(stencil) <: AbstractVector ? stencil[i] : stencil
18+
cur_stencil = eltype(stencil) <: AbstractVector ? stencil[i-A.boundary_point_count] : stencil
1919
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
2020
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
2121
for idx in 1:A.stencil_length

src/derivative_operators/derivative_operator.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,4 @@ end
166166
# Convenience method: forward every argument unchanged to `CenteredDifference{1}`.
function CenteredDifference(args...)
    return CenteredDifference{1}(args...)
end
167167
# Convenience method: forward every argument unchanged to `UpwindDifference{1}`.
function UpwindDifference(args...)
    return UpwindDifference{1}(args...)
end
168168
# Return the operator's third type parameter `Wind`; call sites use it as a
# boolean flag (e.g. `use_winding(A) && cur_coeff < 0`).
function use_winding(A::DerivativeOperator{T,N,Wind}) where {T,N,Wind}
    return Wind
end
169+
# Return the operator's second type parameter `N`, i.e. the same value that
# `typeof(L).parameters[2]` is read as the axis elsewhere in this change.
function diff_axis(A::DerivativeOperator{T,N}) where {T,N}
    return N
end

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 134 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -145,147 +145,165 @@ end
145145
# DerivativeOperator operating on a 2D or 3D AbstractArray
146146
function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqCompositeOperator, M::AbstractArray{T,2}) where {T}
147147

148-
#Check that composite operator satisfies: regular-grid, centered difference:
149-
#for L in A.ops
150-
# if L ... (does not satisfy conditions)
151-
# return (call fall back for multiplication of composite operators)
152-
# end
153-
#end
154-
155-
ndimsM = ndims(M)
156-
Wdims = ones(Int64,ndimsM)
157-
pad = zeros(Int64, ndimsM)
158-
159-
# compute dimensions of interior kernel W
148+
# opsA operators satisfy conditions for NNlib.conv! call, opsB operators do not
149+
opsA = DerivativeOperator[]
150+
opsB = DerivativeOperator[]
160151
for L in A.ops
161-
axis = typeof(L).parameters[2]
162-
@assert axis <= ndimsM
163-
Wdims[axis] = max(Wdims[axis],L.stencil_length)
164-
pad[axis] = max(pad[axis], L.boundary_point_count)
152+
if (L.coefficients isa Number || L.coefficients === nothing) && use_winding(L) == false && L.dx isa Number
153+
push!(opsA, L)
154+
else
155+
push!(opsB,L)
156+
end
165157
end
166158

167-
# create zero-valued kernel
168-
W = zeros(T, Wdims...)
169-
mid_Wdims = div.(Wdims,2).+1
170-
idx = div.(Wdims,2).+1
159+
# Check that we can make at least one NNlib.conv! call
160+
if !isempty(opsA)
161+
#TODO replace A.ops with opsA in here
162+
ndimsM = ndims(M)
163+
Wdims = ones(Int64,ndimsM)
164+
pad = zeros(Int64, ndimsM)
165+
166+
# compute dimensions of interior kernel W
167+
for L in A.ops
168+
axis = typeof(L).parameters[2]
169+
@assert axis <= ndimsM
170+
Wdims[axis] = max(Wdims[axis],L.stencil_length)
171+
pad[axis] = max(pad[axis], L.boundary_point_count)
172+
end
171173

172-
# add to kernel each stencil
173-
for L in A.ops
174-
s = L.stencil_coefs
175-
sl = L.stencil_length
176-
axis = typeof(L).parameters[2]
177-
offset = convert(Int64,(Wdims[axis] - sl)/2)
178-
coeff = L.coefficients isa Number ? L.coefficients : true
179-
for i in offset+1:Wdims[axis]-offset
180-
idx[axis]=i
181-
W[idx...] += coeff*s[i-offset]
182-
idx[axis] = mid_Wdims[axis]
174+
# create zero-valued kernel
175+
W = zeros(T, Wdims...)
176+
mid_Wdims = div.(Wdims,2).+1
177+
idx = div.(Wdims,2).+1
178+
179+
# add to kernel each stencil
180+
for L in A.ops
181+
s = L.stencil_coefs
182+
sl = L.stencil_length
183+
axis = typeof(L).parameters[2]
184+
offset = convert(Int64,(Wdims[axis] - sl)/2)
185+
coeff = L.coefficients isa Number ? L.coefficients : true
186+
for i in offset+1:Wdims[axis]-offset
187+
idx[axis]=i
188+
W[idx...] += coeff*s[i-offset]
189+
idx[axis] = mid_Wdims[axis]
190+
end
183191
end
184-
end
185192

186-
# Reshape x_temp for NNlib.conv!
187-
_x_temp = reshape(x_temp, (size(x_temp)...,1,1))
193+
# Reshape x_temp for NNlib.conv!
194+
_x_temp = reshape(x_temp, (size(x_temp)...,1,1))
188195

189-
# Reshape M for NNlib.conv!
190-
_M = reshape(M, (size(M)...,1,1))
196+
# Reshape M for NNlib.conv!
197+
_M = reshape(M, (size(M)...,1,1))
191198

192-
_W = reshape(W, (size(W)...,1,1))
199+
_W = reshape(W, (size(W)...,1,1))
193200

194-
# Call NNlib.conv!
195-
cv = DenseConvDims(_M, _W, padding=pad, flipkernel=true)
196-
conv!(_x_temp, _M, _W, cv)
201+
# Call NNlib.conv!
202+
cv = DenseConvDims(_M, _W, padding=pad, flipkernel=true)
203+
conv!(_x_temp, _M, _W, cv)
197204

198205

199-
# convolve boundary and interior points near boundary
200-
# partition operator indices along axis of differentiation
201-
if pad[1] > 0 || pad[2] > 0
202-
ops_1 = Int64[]
203-
ops_1_max_bpc_idx = [0]
204-
ops_2 = Int64[]
205-
ops_2_max_bpc_idx = [0]
206-
for i in 1:length(A.ops)
207-
L = A.ops[i]
208-
if typeof(L).parameters[2] == 1
209-
push!(ops_1,i)
210-
if L.boundary_point_count == pad[1]
211-
ops_1_max_bpc_idx[1] = i
212-
end
213-
else
214-
push!(ops_2,i)
215-
if L.boundary_point_count == pad[2]
216-
ops_2_max_bpc_idx[1]= i
206+
# convolve boundary and interior points near boundary
207+
# partition operator indices along axis of differentiation
208+
if pad[1] > 0 || pad[2] > 0
209+
ops_1 = Int64[]
210+
ops_1_max_bpc_idx = [0]
211+
ops_2 = Int64[]
212+
ops_2_max_bpc_idx = [0]
213+
for i in 1:length(A.ops)
214+
L = A.ops[i]
215+
if typeof(L).parameters[2] == 1
216+
push!(ops_1,i)
217+
if L.boundary_point_count == pad[1]
218+
ops_1_max_bpc_idx[1] = i
219+
end
220+
else
221+
push!(ops_2,i)
222+
if L.boundary_point_count == pad[2]
223+
ops_2_max_bpc_idx[1]= i
224+
end
217225
end
218226
end
219-
end
220227

221-
# need offsets since some axis may have ghost nodes and some may not
222-
offset_x = 0
223-
offset_y = 0
228+
# need offsets since some axis may have ghost nodes and some may not
229+
offset_x = 0
230+
offset_y = 0
224231

225-
if length(ops_2) > 0
226-
offset_x = 1
227-
end
228-
if length(ops_1) > 0
229-
offset_y =1
230-
end
232+
if length(ops_2) > 0
233+
offset_x = 1
234+
end
235+
if length(ops_1) > 0
236+
offset_y =1
237+
end
231238

232-
# convolve boundaries and unaccounted for interior in axis 1
233-
if length(ops_1) > 0
234-
for i in 1:size(x_temp)[2]
235-
convolve_BC_left!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
236-
convolve_BC_right!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
237-
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
238-
convolve_interior!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
239-
end
240-
#scale by dx
241-
242-
for Lidx in ops_1
243-
if Lidx != ops_1_max_bpc_idx[1]
244-
convolve_BC_left_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
245-
convolve_BC_right_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
246-
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
247-
convolve_interior_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
248-
elseif pad[1] - A.ops[Lidx].boundary_point_count > 0
249-
convolve_interior_add_range!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx], pad[1] - A.ops[Lidx].boundary_point_count)
239+
# convolve boundaries and unaccounted for interior in axis 1
240+
if length(ops_1) > 0
241+
for i in 1:size(x_temp)[2]
242+
convolve_BC_left!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
243+
convolve_BC_right!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
244+
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
245+
convolve_interior!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
246+
end
247+
248+
for Lidx in ops_1
249+
if Lidx != ops_1_max_bpc_idx[1]
250+
convolve_BC_left_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
251+
convolve_BC_right_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
252+
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
253+
convolve_interior_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
254+
elseif pad[1] - A.ops[Lidx].boundary_point_count > 0
255+
convolve_interior_add_range!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx], pad[1] - A.ops[Lidx].boundary_point_count)
256+
end
250257
end
251258
end
252259
end
253260
end
254-
end
255-
# convolve boundaries and unaccounted for interior in axis 2
256-
if length(ops_2) > 0
257-
for i in 1:size(x_temp)[1]
258-
# in the case of no axis 1 operators, we need to overwrite x_temp
259-
if length(ops_1) == 0
260-
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
261-
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
262-
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
263-
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
264-
end
265-
#scale by dx
266-
# fix here as well
267-
else
268-
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
269-
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
270-
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
271-
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
272-
end
273-
#scale by dx
274-
# fix here as well
275-
end
276-
for Lidx in ops_2
277-
if Lidx != ops_2_max_bpc_idx[1]
278-
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
279-
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
261+
# convolve boundaries and unaccounted for interior in axis 2
262+
if length(ops_2) > 0
263+
for i in 1:size(x_temp)[1]
264+
# in the case of no axis 1 operators, we need to overwrite x_temp
265+
if length(ops_1) == 0
266+
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
267+
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
268+
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
269+
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
270+
end
271+
#scale by dx
272+
# fix here as well
273+
else
274+
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
275+
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
280276
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
281-
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
282-
elseif pad[2] - A.ops[Lidx].boundary_point_count > 0
283-
convolve_interior_add_range!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx], pad[2] - A.ops[Lidx].boundary_point_count)
277+
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
278+
end
279+
#scale by dx
280+
# fix here as well
281+
end
282+
for Lidx in ops_2
283+
if Lidx != ops_2_max_bpc_idx[1]
284+
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
285+
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
286+
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
287+
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
288+
elseif pad[2] - A.ops[Lidx].boundary_point_count > 0
289+
convolve_interior_add_range!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx], pad[2] - A.ops[Lidx].boundary_point_count)
290+
end
284291
end
285292
end
286293
end
287294
end
288295
end
296+
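# Call everything in A.ops using the fallback
# NOTE(review): both `mul!(view(x_temp,A.ops[1],M)` calls below are missing a closing parenthesis and will not parse as written; the intended view arguments need to be confirmed.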
297+
else
298+
N = diff_axis(A.ops[1])
299+
if N == 1
300+
mul!(view(x_temp,A.ops[1],M)
301+
else
302+
mul!(view(x_temp,A.ops[1],M)
303+
end
304+
for L in A.ops[2:end]
305+
mul_add!(x_temp,L,M)
306+
end
289307
end
290308
end
291309

@@ -299,7 +317,7 @@ function convolve_interior_add!(x_temp::AbstractVector{T}, x::AbstractVector{T},
299317
mid = div(A.stencil_length,2)
300318
for i in (1+A.boundary_point_count) : (length(x_temp)-A.boundary_point_count)
301319
xtempi = zero(T)
302-
cur_stencil = eltype(stencil) <: AbstractVector ? stencil[i] : stencil
320+
cur_stencil = eltype(stencil) <: AbstractVector ? stencil[i-A.boundary_point_count] : stencil
303321
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
304322
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
305323
for idx in 1:A.stencil_length

test/2D_3D_fast_multiplication.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -482,14 +482,14 @@ end
482482
# Test composition of all first-dimension operators
483483
A = Lx2+Lx3+Lx4
484484
M_temp = zeros(N,N+2)
485-
@test_broken mul!(M_temp, A, M)
486-
@test_broken M_temp ≈ (Lx2*M + Lx3*M + Lx4*M)
485+
mul!(M_temp, A, M)
486+
@test M_temp ≈ (Lx2*M + Lx3*M + Lx4*M)
487487

488488
# Test composition of all second-dimension operators
489489
A = Ly2+Ly3+Ly4
490490
M_temp = zeros(N+2,N)
491-
@test_broken mul!(M_temp, A, M)
492-
@test_broken M_temp ≈ (Ly2*M + Ly3*M + Ly4*M)
491+
mul!(M_temp, A, M)
492+
@test M_temp ≈ (Ly2*M + Ly3*M + Ly4*M)
493493

494494
# Test composition of all operators
495495
A = Lx2+Lx3+Lx4+Ly2+Ly3+Ly4

0 commit comments

Comments
 (0)