Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit 6305844

Browse files
committed
added overwrite keyword argument for convolutions
1 parent b35ddd1 commit 6305844

File tree

2 files changed

+37
-115
lines changed

2 files changed

+37
-115
lines changed
Lines changed: 18 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# mul! done by convolutions
2-
function LinearAlgebra.mul!(x_temp::AbstractVector{T}, A::DerivativeOperator, x::AbstractVector{T}) where T<:Real
3-
convolve_BC_left!(x_temp, x, A)
4-
convolve_interior!(x_temp, x, A)
5-
convolve_BC_right!(x_temp, x, A)
2+
function LinearAlgebra.mul!(x_temp::AbstractVector{T}, A::DerivativeOperator, x::AbstractVector{T}; overwrite = true) where T<:Real
3+
convolve_BC_left!(x_temp, x, A, overwrite = overwrite)
4+
convolve_interior!(x_temp, x, A, overwrite = overwrite)
5+
convolve_BC_right!(x_temp, x, A, overwrite = overwrite)
66
end
77

88
################################################
99

1010
# Against a standard vector, assume already padded and just apply the stencil
11-
function convolve_interior!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator) where {T<:Real}
11+
function convolve_interior!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator; overwrite = true) where {T<:Real}
1212
@assert length(x_temp)+2 == length(x)
1313
stencil = A.stencil_coefs
1414
coeff = A.coefficients
@@ -21,11 +21,11 @@ function convolve_interior!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::
2121
for idx in 1:A.stencil_length
2222
xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
2323
end
24-
x_temp[i] = xtempi
24+
x_temp[i] = xtempi + !:($overwrite)*x_temp[i]
2525
end
2626
end
2727

28-
function convolve_BC_left!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator) where {T<:Real}
28+
function convolve_BC_left!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator; overwrite = true) where {T<:Real}
2929
stencil = A.low_boundary_coefs
3030
coeff = A.coefficients
3131
for i in 1 : A.boundary_point_count
@@ -36,11 +36,11 @@ function convolve_BC_left!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::D
3636
for idx in 2:A.boundary_stencil_length
3737
xtempi += cur_coeff * cur_stencil[idx] * x[idx]
3838
end
39-
x_temp[i] = xtempi
39+
x_temp[i] = xtempi + !:($overwrite)*x_temp[i]
4040
end
4141
end
4242

43-
function convolve_BC_right!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator) where {T<:Real}
43+
function convolve_BC_right!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator; overwrite = true) where {T<:Real}
4444
stencil = A.high_boundary_coefs
4545
coeff = A.coefficients
4646
for i in 1 : A.boundary_point_count
@@ -51,14 +51,14 @@ function convolve_BC_right!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::
5151
for idx in (A.boundary_stencil_length-1):-1:1
5252
xtempi += cur_coeff * cur_stencil[end-idx] * x[end-idx]
5353
end
54-
x_temp[end-A.boundary_point_count+i] = xtempi
54+
x_temp[end-A.boundary_point_count+i] = xtempi + !:($overwrite)*x_temp[end-A.boundary_point_count+i]
5555
end
5656
end
5757

5858
###########################################
5959

6060
# Against A BC-padded vector, specialize the computation to explicitly use the left, right, and middle parts
61-
function convolve_interior!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector, A::DerivativeOperator) where {T<:Real}
61+
function convolve_interior!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector, A::DerivativeOperator; overwrite = true) where {T<:Real}
6262
stencil = A.stencil_coefs
6363
coeff = A.coefficients
6464
x = _x.u
@@ -72,11 +72,11 @@ function convolve_interior!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector,
7272
@inbounds for idx in 1:A.stencil_length
7373
xtempi += cur_coeff * cur_stencil[idx] * x[(i-1) - (mid-idx) + 1]
7474
end
75-
x_temp[i] = xtempi
75+
x_temp[i] = xtempi + !:($overwrite)*x_temp[i]
7676
end
7777
end
7878

79-
function convolve_BC_left!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector, A::DerivativeOperator) where {T<:Real}
79+
function convolve_BC_left!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector, A::DerivativeOperator; overwrite = true) where {T<:Real}
8080
stencil = A.low_boundary_coefs
8181
coeff = A.coefficients
8282
for i in 1 : A.boundary_point_count
@@ -87,7 +87,7 @@ function convolve_BC_left!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector,
8787
@inbounds for idx in 2:A.boundary_stencil_length
8888
xtempi += cur_coeff * cur_stencil[idx] * _x.u[idx-1]
8989
end
90-
x_temp[i] = xtempi
90+
x_temp[i] = xtempi + !:($overwrite)*x_temp[i]
9191
end
9292
# need to account for x.l in first interior
9393
mid = div(A.stencil_length,2) + 1
@@ -101,10 +101,10 @@ function convolve_BC_left!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector,
101101
@inbounds for idx in 2:A.stencil_length
102102
xtempi += cur_coeff * cur_stencil[idx] * x[(i-1) - (mid-idx) + 1]
103103
end
104-
x_temp[i] = xtempi
104+
x_temp[i] = xtempi + !:($overwrite)*x_temp[i]
105105
end
106106

107-
function convolve_BC_right!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector, A::DerivativeOperator) where {T<:Real}
107+
function convolve_BC_right!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector, A::DerivativeOperator; overwrite = true) where {T<:Real}
108108
stencil = A.high_boundary_coefs
109109
coeff = A.coefficients
110110
bc_start = length(_x.u) - A.boundary_point_count
@@ -120,7 +120,7 @@ function convolve_BC_right!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector,
120120
@inbounds for idx in 1:A.stencil_length-1
121121
xtempi += cur_coeff * cur_stencil[idx] * x[(i-1) - (mid-idx) + 1]
122122
end
123-
x_temp[i] = xtempi
123+
x_temp[i] = xtempi + !:($overwrite)*x_temp[i]
124124
for i in 1 : A.boundary_point_count
125125
cur_stencil = stencil[i]
126126
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[bc_start + i] : coeff isa Number ? coeff : true
@@ -129,30 +129,12 @@ function convolve_BC_right!(x_temp::AbstractVector{T}, _x::BoundaryPaddedVector,
129129
@inbounds for idx in A.stencil_length:-1:1
130130
xtempi += cur_coeff * cur_stencil[end-idx] * _x.u[end-idx+1]
131131
end
132-
x_temp[bc_start + i] = xtempi
132+
x_temp[bc_start + i] = xtempi + !:($overwrite)*x_temp[bc_start + i]
133133
end
134134
end
135135

136136
###########################################
137137

138-
# Implementations of additive convolutions, necessary for compositions of operators
139-
function convolve_interior_add!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator) where {T<:Real}
140-
@assert length(x_temp)+2 == length(x)
141-
stencil = A.stencil_coefs
142-
coeff = A.coefficients
143-
mid = div(A.stencil_length,2)
144-
for i in (1+A.boundary_point_count) : (length(x_temp)-A.boundary_point_count)
145-
xtempi = zero(T)
146-
cur_stencil = eltype(stencil) <: AbstractVector ? stencil[i-A.boundary_point_count] : stencil
147-
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
148-
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
149-
for idx in 1:A.stencil_length
150-
xtempi += cur_coeff * cur_stencil[idx] * x[i - mid + idx]
151-
end
152-
x_temp[i] += xtempi
153-
end
154-
end
155-
156138
function convolve_interior_add_range!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator, offset::Int) where {T<:Real}
157139
@assert length(x_temp)+2 == length(x)
158140
stencil = A.stencil_coefs
@@ -169,33 +151,3 @@ function convolve_interior_add_range!(x_temp::AbstractVector{T}, x::AbstractVect
169151
x_temp[i] += xtempi
170152
end
171153
end
172-
173-
function convolve_BC_left_add!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator) where {T<:Real}
174-
stencil = A.low_boundary_coefs
175-
coeff = A.coefficients
176-
for i in 1 : A.boundary_point_count
177-
cur_stencil = stencil[i]
178-
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
179-
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
180-
xtempi = cur_coeff*stencil[i][1]*x[1]
181-
for idx in 2:A.boundary_stencil_length
182-
xtempi += cur_coeff * cur_stencil[idx] * x[idx]
183-
end
184-
x_temp[i] += xtempi
185-
end
186-
end
187-
188-
function convolve_BC_right_add!(x_temp::AbstractVector{T}, x::AbstractVector{T}, A::DerivativeOperator) where {T<:Real}
189-
stencil = A.high_boundary_coefs
190-
coeff = A.coefficients
191-
for i in 1 : A.boundary_point_count
192-
cur_stencil = stencil[i]
193-
cur_coeff = typeof(coeff) <: AbstractVector ? coeff[i] : coeff isa Number ? coeff : true
194-
cur_stencil = use_winding(A) && cur_coeff < 0 ? reverse(cur_stencil) : cur_stencil
195-
xtempi = cur_coeff*stencil[i][end]*x[end]
196-
for idx in (A.boundary_stencil_length-1):-1:1
197-
xtempi += cur_coeff * cur_stencil[end-idx] * x[end-idx]
198-
end
199-
x_temp[end-A.boundary_point_count+i] += xtempi
200-
end
201-
end

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 19 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Fallback mul! implementation for a single DerivativeOperator operating on an AbstractArray
2-
function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::AbstractArray{T}) where {T,N}
2+
function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::AbstractArray{T}; overwrite = true) where {T,N}
33

44
# Check that x_temp has correct dimensions
55
v = zeros(ndims(x_temp))
@@ -22,37 +22,7 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}
2222
setindex!(idx, :, N)
2323
for I in indices
2424
Base.replace_tuples!(nidx, idx, idx, otherdims, I)
25-
mul!(view(x_temp, idx...), A, view(M, idx...))
26-
end
27-
end
28-
29-
# Additive mul! fallback that is necessary for handling compositions
30-
function mul_add!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::AbstractArray{T}) where {T,N}
31-
32-
# Check that x_temp has correct dimensions
33-
v = zeros(ndims(x_temp))
34-
v[N] = 2
35-
@assert [size(x_temp)...]+v == [size(M)...]
36-
37-
# Check that axis of differentiation is in the dimensions of M and x_temp
38-
ndimsM = ndims(M)
39-
@assert N <= ndimsM
40-
41-
dimsM = [axes(M)...]
42-
alldims = [1:ndims(M);]
43-
otherdims = setdiff(alldims, N)
44-
45-
idx = Any[first(ind) for ind in axes(M)]
46-
itershape = tuple(dimsM[otherdims]...)
47-
nidx = length(otherdims)
48-
indices = Iterators.drop(CartesianIndices(itershape), 0)
49-
50-
setindex!(idx, :, N)
51-
for I in indices
52-
Base.replace_tuples!(nidx, idx, idx, otherdims, I)
53-
convolve_interior_add!(view(x_temp, idx...), view(M, idx...), A)
54-
convolve_BC_right_add!(view(x_temp, idx...), view(M, idx...), A)
55-
convolve_BC_left_add!(view(x_temp, idx...), view(M, idx...), A)
25+
mul!(view(x_temp, idx...), A, view(M, idx...), overwrite = overwrite)
5626
end
5727
end
5828

@@ -252,10 +222,10 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
252222

253223
for Lidx in ops_1
254224
if Lidx != ops_1_max_bpc_idx[1]
255-
convolve_BC_left_add!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx])
256-
convolve_BC_right_add!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx])
225+
convolve_BC_left!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], overwrite = false)
226+
convolve_BC_right!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], overwrite = false)
257227
if i <= pad[2] || i > size(x_temp)[2]-pad[2]
258-
convolve_interior_add!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx])
228+
convolve_interior!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], overwrite = false)
259229
elseif pad[1] - opsA[Lidx].boundary_point_count > 0
260230
convolve_interior_add_range!(view(x_temp,:,i), view(M,:,i+offset_x), opsA[Lidx], pad[1] - opsA[Lidx].boundary_point_count)
261231
end
@@ -275,19 +245,19 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
275245
end
276246

277247
else
278-
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
279-
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
248+
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...], overwrite = false)
249+
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...], overwrite = false)
280250
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
281-
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...])
251+
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[ops_2_max_bpc_idx...], overwrite = false)
282252
end
283253

284254
end
285255
for Lidx in ops_2
286256
if Lidx != ops_2_max_bpc_idx[1]
287-
convolve_BC_left_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx])
288-
convolve_BC_right_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx])
257+
convolve_BC_left!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], overwrite = false)
258+
convolve_BC_right!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], overwrite = false)
289259
if i <= pad[1] || i > size(x_temp)[1]-pad[1]
290-
convolve_interior_add!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx])
260+
convolve_interior!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], overwrite = false)
291261
elseif pad[2] - opsA[Lidx].boundary_point_count > 0
292262
convolve_interior_add_range!(view(x_temp,i,:), view(M,i+offset_y,:), opsA[Lidx], pad[2] - opsA[Lidx].boundary_point_count)
293263
end
@@ -316,15 +286,15 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
316286
N = diff_axis(L)
317287
if N == 1
318288
if operating_dims[2] == 1
319-
mul_add!(x_temp,L,view(M,1:x_temp_1+2,1:x_temp_2))
289+
mul!(x_temp,L,view(M,1:x_temp_1+2,1:x_temp_2), overwrite = false)
320290
else
321-
mul_add!(x_temp,L,M)
291+
mul!(x_temp,L,M, overwrite = false)
322292
end
323293
else
324294
if operating_dims[1] == 1
325-
mul_add!(x_temp,L,view(M,1:x_temp_1,1:x_temp_2+2))
295+
mul!(x_temp,L,view(M,1:x_temp_1,1:x_temp_2+2), overwrite = false)
326296
else
327-
mul_add!(x_temp,L,M)
297+
mul!(x_temp,L,M, overwrite = false)
328298
end
329299
end
330300
end
@@ -363,15 +333,15 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T,2}, A::AbstractDiffEqComposi
363333
N = diff_axis(L)
364334
if N == 1
365335
if operating_dims[2] == 1
366-
mul_add!(x_temp,L,view(M,1:x_temp_1+2,1:x_temp_2))
336+
mul!(x_temp,L,view(M,1:x_temp_1+2,1:x_temp_2), overwrite = false)
367337
else
368-
mul_add!(x_temp,L,M)
338+
mul!(x_temp,L,M, overwrite = false)
369339
end
370340
else
371341
if operating_dims[1] == 1
372-
mul_add!(x_temp,L,view(M,1:x_temp_1,1:x_temp_2+2))
342+
mul!(x_temp,L,view(M,1:x_temp_1,1:x_temp_2+2), overwrite = false)
373343
else
374-
mul_add!(x_temp,L,M)
344+
mul!(x_temp,L,M, overwrite = false)
375345
end
376346
end
377347
end

0 commit comments

Comments
 (0)