Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit 7213994

Browse files
Merge pull request #128 from JuliaDiffEq/nnlib_convolutions
mul! implemented for 2d and 3d multiplication with NNlib
2 parents 13be768 + ce1269e commit 7213994

File tree

3 files changed

+68
-1
lines changed

3 files changed

+68
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1212
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
1313
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
14+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
15+
1416

1517
[compat]
1618
julia = "1"

src/DiffEqOperators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Base: +, -, *, /, \, size, getindex, setindex!, Matrix, convert
44
using DiffEqBase, StaticArrays, LinearAlgebra
55
import LinearAlgebra: mul!, ldiv!, lmul!, rmul!, axpy!, opnorm, factorize, I
66
import DiffEqBase: AbstractDiffEqLinearOperator, update_coefficients!, is_constant
7-
using SparseArrays, ForwardDiff, BandedMatrices
7+
using SparseArrays, ForwardDiff, BandedMatrices, NNlib
88

99
abstract type AbstractDerivativeOperator{T} <: AbstractDiffEqLinearOperator{T} end
1010
abstract type AbstractDiffEqCompositeOperator{T} <: AbstractDiffEqLinearOperator{T} end

src/derivative_operators/derivative_operator_functions.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::AbstractArray{T}) where {T<:Real,N}
22

3+
# Check that x_temp has correct dimensions
34
v = zeros(ndims(x_temp))
45
v[N] = 2
56
@assert [size(x_temp)...]+v == [size(M)...]
67

8+
# Check that axis of differentiation is in the dimensions of M and x_temp
79
ndimsM = ndims(M)
810
@assert N <= ndimsM
911

@@ -23,6 +25,69 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}
2325
end
2426
end
2527

28+
for MT in [2,3]
29+
@eval begin
30+
function LinearAlgebra.mul!(x_temp::AbstractArray{T,$MT}, A::DerivativeOperator{T,N,Wind,T2,S1}, M::AbstractArray{T,$MT}) where {T<:Real,N,Wind,T2,SL,S1<:SArray{Tuple{SL},T,1,SL}}
31+
32+
# Check that x_temp has correct dimensions
33+
v = zeros(ndims(x_temp))
34+
v[N] = 2
35+
@assert [size(x_temp)...]+v == [size(M)...]
36+
37+
# Check that axis of differentiation is in the dimensions of M and x_temp
38+
ndimsM = ndims(M)
39+
@assert N <= ndimsM
40+
41+
# Respahe x_temp for NNlib.conv!
42+
new_size = Any[size(x_temp)...]
43+
bpc = A.boundary_point_count
44+
setindex!(new_size, new_size[N]- 2*bpc, N)
45+
new_shape = []
46+
for i in 1:ndimsM
47+
if i != N
48+
push!(new_shape,:)
49+
else
50+
push!(new_shape,bpc+1:new_size[N]+bpc)
51+
end
52+
end
53+
_x_temp = reshape(view(x_temp, new_shape...), (new_size...,1,1))
54+
55+
# Reshape M for NNlib.conv!
56+
_M = reshape(M, (size(M)...,1,1))
57+
s = A.stencil_coefs
58+
sl = A.stencil_length
59+
60+
# Setup W, the kernel for NNlib.conv!
61+
Wdims = ones(Int64, ndims(_x_temp))
62+
Wdims[N] = sl
63+
W = zeros(Wdims...)
64+
Widx = Any[Wdims...]
65+
setindex!(Widx,:,N)
66+
W[Widx...] = s ./ A.dx^A.derivative_order # this will change later
67+
cv = DenseConvDims(_M, W)
68+
69+
conv!(_x_temp, _M, W, cv)
70+
71+
# Now deal with boundaries
72+
dimsM = [axes(M)...]
73+
alldims = [1:ndims(M);]
74+
otherdims = setdiff(alldims, N)
75+
76+
idx = Any[first(ind) for ind in axes(M)]
77+
itershape = tuple(dimsM[otherdims]...)
78+
nidx = length(otherdims)
79+
indices = Iterators.drop(CartesianIndices(itershape), 0)
80+
81+
setindex!(idx, :, N)
82+
for I in indices
83+
Base.replace_tuples!(nidx, idx, idx, otherdims, I)
84+
convolve_BC_left!(view(x_temp, idx...), view(M, idx...), A)
85+
convolve_BC_right!(view(x_temp, idx...), view(M, idx...), A)
86+
end
87+
end
88+
end
89+
end
90+
2691
function *(A::DerivativeOperator{T,N},M::AbstractArray{T}) where {T<:Real,N}
2792
size_x_temp = [size(M)...]
2893
size_x_temp[N] -= 2

0 commit comments

Comments
 (0)