Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit 22308ab

Browse files
committed
all tests passing now, requires a few irregular and regular operator composition tests
1 parent 7b390f8 commit 22308ab

File tree

2 files changed

+60
-27
lines changed

2 files changed

+60
-27
lines changed

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
166166
pad = zeros(Int64, ndimsM)
167167

168168
# compute dimensions of interior kernel W
169+
# Here we still use A.ops since the other dimensions may indicate that
170+
# we have more padding to account for
169171
for L in A.ops
170172
axis = typeof(L).parameters[2]
171173
@assert axis <= ndimsM
@@ -179,7 +181,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
179181
idx = div.(Wdims,2).+1
180182

181183
# add to kernel each stencil
182-
for L in A.ops
184+
for L in opsA
183185
s = L.stencil_coefs
184186
sl = L.stencil_length
185187
axis = typeof(L).parameters[2]
@@ -212,8 +214,8 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
212214
ops_1_max_bpc_idx = [0]
213215
ops_2 = Int64[]
214216
ops_2_max_bpc_idx = [0]
215-
for i in 1:length(A.ops)
216-
L = A.ops[i]
217+
for i in 1:length(opsA)
218+
L = opsA[i]
217219
if typeof(L).parameters[2] == 1
218220
push!(ops_1,i)
219221
if L.boundary_point_count == pad[1]
@@ -241,20 +243,20 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
241243
# convolve boundaries and unaccounted for interior in axis 1
242244
if length(ops_1) > 0
243245
for i in 1:size(x_temp)[2]
244-
convolve_BC_left!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
245-
convolve_BC_right!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
246+
convolve_BC_left!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[ops_1_max_bpc_idx...])
247+
convolve_BC_right!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[ops_1_max_bpc_idx...])
246248
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
247-
convolve_interior!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[ops_1_max_bpc_idx...])
249+
convolve_interior!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[ops_1_max_bpc_idx...])
248250
end
249251

250252
for Lidx in ops_1
251253
if Lidx != ops_1_max_bpc_idx[1]
252-
convolve_BC_left_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
253-
convolve_BC_right_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
254+
convolve_BC_left_add!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx])
255+
convolve_BC_right_add!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx])
254256
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
255-
convolve_interior_add!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx])
256-
elseif pad[1] - A.ops[Lidx].boundary_point_count > 0
257-
convolve_interior_add_range!(view(x_temp,:,i), view(M,:,i+offset_x), A.ops[Lidx], pad[1] - A.ops[Lidx].boundary_point_count)
257+
convolve_interior_add!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx])
258+
elseif pad[1] - opsA[Lidx].boundary_point_count > 0
259+
convolve_interior_add_range!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], pad[1] - opsA[Lidx].boundary_point_count)
258260
end
259261
end
260262
end
@@ -265,36 +267,67 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
265267
for i in 1:size(x_temp)[1]
266268
# in the case of no axis 1 operators, we need to overwrite x_temp
267269
if length(ops_1) == 0
268-
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
269-
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
270+
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
271+
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
270272
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
271-
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
273+
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
272274
end
273275
#scale by dx
274276
# fix here as well
275277
else
276-
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
277-
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
278+
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
279+
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
278280
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
279-
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[ops_2_max_bpc_idx...])
281+
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
280282
end
281283
#scale by dx
282284
# fix here as well
283285
end
284286
for Lidx in ops_2
285287
if Lidx != ops_2_max_bpc_idx[1]
286-
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
287-
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
288+
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx])
289+
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx])
288290
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
289-
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx])
290-
elseif pad[2] - A.ops[Lidx].boundary_point_count > 0
291-
convolve_interior_add_range!(view(x_temp,i,:), view(M,i+offset_y,:), A.ops[Lidx], pad[2] - A.ops[Lidx].boundary_point_count)
291+
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx])
292+
elseif pad[2] - opsA[Lidx].boundary_point_count > 0
293+
convolve_interior_add_range!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], pad[2] - opsA[Lidx].boundary_point_count)
292294
end
293295
end
294296
end
295297
end
296298
end
297299
end
300+
#operating_dims
301+
operating_dims = zeros(Int64,2)
302+
# need to consider all dimensions and operators to determine the truncation
303+
# of M to x_temp
304+
for L in A.ops
305+
if diff_axis(L) == 1
306+
operating_dims[1] = 1
307+
else
308+
operating_dims[2] = 1
309+
end
310+
end
311+
312+
x_temp_1, x_temp_2 = size(x_temp)
313+
314+
for L in opsB
315+
N = diff_axis(L)
316+
if N == 1
317+
if operating_dims[2] == 1
318+
mul_add!(x_temp,L,view(M,1:x_temp_1+2,1:x_temp_2))
319+
else
320+
mul_add!(x_temp,L,M)
321+
end
322+
else
323+
if operating_dims[1] == 1
324+
mul_add!(x_temp,L,view(M,1:x_temp_1,1:x_temp_2+2))
325+
else
326+
mul_add!(x_temp,L,M)
327+
end
328+
end
329+
end
330+
298331
# Call everything A.ops using fallback
299332
else
300333
#operating_dims
@@ -309,7 +342,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
309342

310343
x_temp_1, x_temp_2 = size(x_temp)
311344

312-
# Handle first case additively
345+
# Handle first case non-additively
313346
N = diff_axis(A.ops[1])
314347
if N == 1
315348
if operating_dims[2] == 1

test/2D_3D_fast_multiplication.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,8 @@ end
539539
# Test that composition of both x and y operators works
540540
A = Lx2 + Ly2 + Lx3 + Ly3 + Ly4 + Lx4
541541
M_temp = zeros(N,N)
542-
@test_broken mul!(M_temp, A, M)
543-
@test_broken M_temp ≈ ((Lx2*M)[1:N,2:N+1]+(Lx3*M)[1:N,2:N+1]+(Lx4*M)[1:N,2:N+1]+(Ly2*M)[2:N+1,1:N]+(Ly3*M)[2:N+1,1:N]+(Ly4*M)[2:N+1,1:N])
542+
mul!(M_temp, A, M)
543+
@test M_temp ≈ ((Lx2*M)[1:N,2:N+1]+(Lx3*M)[1:N,2:N+1]+(Lx4*M)[1:N,2:N+1]+(Ly2*M)[2:N+1,1:N]+(Ly3*M)[2:N+1,1:N]+(Ly4*M)[2:N+1,1:N])
544544

545545
end
546546

@@ -586,7 +586,7 @@ end
586586
# Test that composition of both x and y operators works
587587
A = Lx2 + Ly2 + Lx3 + Ly3 + Ly4 + Lx4
588588
M_temp = zeros(N,N)
589-
@test_broken mul!(M_temp, A, M)
590-
@test_broken M_temp ≈ ((Lx2*M)[1:N,2:N+1]+(Lx3*M)[1:N,2:N+1]+(Lx4*M)[1:N,2:N+1]+(Ly2*M)[2:N+1,1:N]+(Ly3*M)[2:N+1,1:N]+(Ly4*M)[2:N+1,1:N])
589+
mul!(M_temp, A, M)
590+
@test M_temp ≈ ((Lx2*M)[1:N,2:N+1]+(Lx3*M)[1:N,2:N+1]+(Lx4*M)[1:N,2:N+1]+(Ly2*M)[2:N+1,1:N]+(Ly3*M)[2:N+1,1:N]+(Ly4*M)[2:N+1,1:N])
591591

592592
end

0 commit comments

Comments
 (0)