Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit af461f6

Browse files
Merge pull request #129 from JuliaDiffEq/nnlib_convolutions
Fixed dispatch 2D/3D multiplication for DerivativeOperator
2 parents 7213994 + 42312c7 commit af461f6

File tree

2 files changed

+133
-37
lines changed

2 files changed

+133
-37
lines changed

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727

2828
for MT in [2,3]
2929
@eval begin
30-
function LinearAlgebra.mul!(x_temp::AbstractArray{T,$MT}, A::DerivativeOperator{T,N,Wind,T2,S1}, M::AbstractArray{T,$MT}) where {T<:Real,N,Wind,T2,SL,S1<:SArray{Tuple{SL},T,1,SL}}
30+
function LinearAlgebra.mul!(x_temp::AbstractArray{T,$MT}, A::DerivativeOperator{T,N,false,T2,S1}, M::AbstractArray{T,$MT}) where {T<:Real,N,T2,SL,S1<:SArray{Tuple{SL},T,1,SL}}
3131

3232
# Check that x_temp has correct dimensions
3333
v = zeros(ndims(x_temp))
@@ -38,52 +38,49 @@ for MT in [2,3]
3838
ndimsM = ndims(M)
3939
@assert N <= ndimsM
4040

41-
# Respahe x_temp for NNlib.conv!
42-
new_size = Any[size(x_temp)...]
41+
# Determine padding for NNlib.conv!
4342
bpc = A.boundary_point_count
44-
setindex!(new_size, new_size[N]- 2*bpc, N)
45-
new_shape = []
46-
for i in 1:ndimsM
47-
if i != N
48-
push!(new_shape,:)
49-
else
50-
push!(new_shape,bpc+1:new_size[N]+bpc)
51-
end
52-
end
53-
_x_temp = reshape(view(x_temp, new_shape...), (new_size...,1,1))
43+
pad = zeros(Int64,ndimsM)
44+
pad[N] = bpc
45+
46+
# Reshape x_temp for NNlib.conv!
47+
_x_temp = reshape(x_temp, (size(x_temp)...,1,1))
5448

5549
# Reshape M for NNlib.conv!
5650
_M = reshape(M, (size(M)...,1,1))
57-
s = A.stencil_coefs
58-
sl = A.stencil_length
5951

6052
# Setup W, the kernel for NNlib.conv!
53+
s = A.stencil_coefs
54+
sl = A.stencil_length
6155
Wdims = ones(Int64, ndims(_x_temp))
6256
Wdims[N] = sl
6357
W = zeros(Wdims...)
6458
Widx = Any[Wdims...]
6559
setindex!(Widx,:,N)
66-
W[Widx...] = s ./ A.dx^A.derivative_order # this will change later
67-
cv = DenseConvDims(_M, W)
60+
W[Widx...] = s
6861

62+
cv = DenseConvDims(_M, W, padding=pad)
6963
conv!(_x_temp, _M, W, cv)
7064

7165
# Now deal with boundaries
72-
dimsM = [axes(M)...]
73-
alldims = [1:ndims(M);]
74-
otherdims = setdiff(alldims, N)
75-
76-
idx = Any[first(ind) for ind in axes(M)]
77-
itershape = tuple(dimsM[otherdims]...)
78-
nidx = length(otherdims)
79-
indices = Iterators.drop(CartesianIndices(itershape), 0)
80-
81-
setindex!(idx, :, N)
82-
for I in indices
83-
Base.replace_tuples!(nidx, idx, idx, otherdims, I)
84-
convolve_BC_left!(view(x_temp, idx...), view(M, idx...), A)
85-
convolve_BC_right!(view(x_temp, idx...), view(M, idx...), A)
66+
if bpc > 0
67+
dimsM = [axes(M)...]
68+
alldims = [1:ndims(M);]
69+
otherdims = setdiff(alldims, N)
70+
71+
idx = Any[first(ind) for ind in axes(M)]
72+
itershape = tuple(dimsM[otherdims]...)
73+
nidx = length(otherdims)
74+
indices = Iterators.drop(CartesianIndices(itershape), 0)
75+
76+
setindex!(idx, :, N)
77+
for I in indices
78+
Base.replace_tuples!(nidx, idx, idx, otherdims, I)
79+
convolve_BC_left!(view(x_temp, idx...), view(M, idx...), A)
80+
convolve_BC_right!(view(x_temp, idx...), view(M, idx...), A)
81+
end
8682
end
83+
mul!(x_temp,x_temp,1/A.dx^A.derivative_order)
8784
end
8885
end
8986
end

test/differentiation_dimension.jl

Lines changed: 105 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,46 @@ function second_derivative_stencil(N)
2121
A
2222
end
2323

24-
@testset "Differenting first three dimensions" begin
24+
@testset "Differentiation on 2D array" begin
25+
M = zeros(22,22)
26+
M_temp = zeros(20,22)
27+
indices = Iterators.drop(CartesianIndices((22,22)), 0)
28+
for idx in indices
29+
M[idx] = sin(idx[1]*0.1)
30+
end
31+
L = CenteredDifference(4,4,0.1,20)
32+
33+
#test mul!
34+
mul!(M_temp, L, M)
35+
36+
correct_row = 10000.0*fourth_deriv_approx_stencil(20)*M[:,1]
37+
38+
for i in 1:22
39+
@test M_temp[:,i] correct_row
40+
end
41+
42+
# Test that * agrees will mul!
43+
@test M_temp == L*M
44+
45+
# Differentiation along second dimension
46+
L = CenteredDifference{2}(4,4,0.1,20)
47+
M_temp_2 = zeros(22,20)
48+
indices = Iterators.drop(CartesianIndices((22,22)), 0)
49+
for idx in indices
50+
M[idx] = sin(idx[2]*0.1)
51+
end
52+
53+
#test mul!
54+
mul!(M_temp_2, L, M)
55+
for i in 1:22
56+
@test M_temp_2[i,:] correct_row
57+
end
58+
59+
# Test that * agrees will mul!
60+
@test M_temp_2 == L*M
61+
end
62+
63+
@testset "Differenting on 3D array with L2" begin
2564
M = zeros(22,22,22)
2665
M_temp = zeros(20,22,22)
2766
indices = Iterators.drop(CartesianIndices((22,22,22)), 0)
@@ -83,21 +122,81 @@ end
83122
@test M_temp_3 == L*M
84123
end
85124

125+
@testset "Differentiation on 3D array with L4" begin
126+
M = zeros(22,22,22)
127+
M_temp = zeros(20,22,22)
128+
indices = Iterators.drop(CartesianIndices((22,22,22)), 0)
129+
for idx in indices
130+
M[idx] = sin(idx[1]*0.1)
131+
end
132+
L = CenteredDifference(4,4,0.1,20)
133+
134+
#test mul!
135+
mul!(M_temp, L, M)
136+
137+
correct_row = 10000.0*fourth_deriv_approx_stencil(20)*M[:,1,1]
138+
139+
for i in 1:22
140+
for j in 1:22
141+
@test M_temp[:,i,j] correct_row
142+
end
143+
end
144+
145+
# Test that * agrees will mul!
146+
@test M_temp == L*M
147+
148+
# Differentiation along second dimension
149+
L = CenteredDifference{2}(4,4,0.1,20)
150+
M_temp_2 = zeros(22,20,22)
151+
for idx in indices
152+
M[idx] = sin(idx[2]*0.1)
153+
end
154+
155+
#test mul!
156+
mul!(M_temp_2, L, M)
157+
for i in 1:22
158+
for j in 1:22
159+
@test M_temp_2[i,:,j] correct_row
160+
end
161+
end
162+
163+
# Test that * agrees will mul!
164+
@test M_temp_2 == L*M
165+
166+
# Differentiation along third dimension
167+
L = CenteredDifference{3}(4,4,0.1,20)
168+
M_temp_3 = zeros(22,22,20)
169+
for idx in indices
170+
M[idx] = sin(idx[3]*0.1)
171+
end
172+
173+
#test mul!
174+
mul!(M_temp_3, L, M)
175+
for i in 1:22
176+
for j in 1:22
177+
@test M_temp_3[i,j,:] correct_row
178+
end
179+
end
180+
181+
# Test that * agrees will mul!
182+
@test M_temp_3 == L*M
183+
end
184+
86185
@testset "Differentiating an arbitrary higher dimension" begin
87186
N = 6
88187
L = CenteredDifference{N}(4,4,0.1,30)
89-
M = zeros(5,5,5,5,5,32)
90-
M_temp = zeros(5,5,5,5,5,30)
91-
indices = Iterators.drop(CartesianIndices((5,5,5,5,5,32)), 0)
188+
M = zeros(5,5,5,5,5,32,5);
189+
M_temp = zeros(5,5,5,5,5,30,5);
190+
indices = Iterators.drop(CartesianIndices((5,5,5,5,5,32,5)), 0);
92191
for idx in indices
93192
M[idx] = cos(idx[N]*0.1)
94193
end
95194

96-
correct_row = (10.0^4)*fourth_deriv_approx_stencil(30)*M[1,1,1,1,1,:]
195+
correct_row = (10.0^4)*fourth_deriv_approx_stencil(30)*M[1,1,1,1,1,:,1]
97196

98197
#test mul!
99198
mul!(M_temp, L, M)
100-
indices = Iterators.drop(CartesianIndices((5,5,5,5,5,30)), 0)
199+
indices = Iterators.drop(CartesianIndices((5,5,5,5,5,30,5)), 0);
101200
for idx in indices
102201
@test M_temp[idx] correct_row[idx[N]]
103202
end

0 commit comments

Comments
 (0)