Skip to content
This repository was archived by the owner on Jul 19, 2023. It is now read-only.

Commit bad3c23

Browse files
Merge pull request #138 from JuliaDiffEq/kronecker
Higher Dimension Concretization
2 parents 54e1c17 + f058ec5 commit bad3c23

File tree

5 files changed

+228
-3
lines changed

5 files changed

+228
-3
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@ version = "4.1.0"
55

66
[deps]
77
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
8+
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
89
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
910
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
11+
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1012
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1113
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1214
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

src/DiffEqOperators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import Base: +, -, *, /, \, size, getindex, setindex!, Matrix, convert
44
using DiffEqBase, StaticArrays, LinearAlgebra
55
import LinearAlgebra: mul!, ldiv!, lmul!, rmul!, axpy!, opnorm, factorize, I
66
import DiffEqBase: AbstractDiffEqLinearOperator, update_coefficients!, is_constant
7-
using SparseArrays, ForwardDiff, BandedMatrices, NNlib
7+
using SparseArrays, ForwardDiff, BandedMatrices, NNlib, LazyArrays, BlockBandedMatrices
88

99
abstract type AbstractDerivativeOperator{T} <: AbstractDiffEqLinearOperator{T} end
1010
abstract type AbstractDiffEqCompositeOperator{T} <: AbstractDiffEqLinearOperator{T} end
@@ -15,7 +15,7 @@ include("matrixfree_operators.jl")
1515
include("jacvec_operators.jl")
1616

1717
### Utilities
18-
include("utils.jl")
18+
include("utils.jl")
1919

2020
### Boundary Padded Arrays
2121
include("boundary_padded_arrays.jl")

src/derivative_operators/concretization.jl

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function SparseArrays.SparseMatrixCSC(A::DerivativeOperator{T}, N::Int=A.len) wh
5050
return L
5151
end
5252

53-
function SparseArrays.sparse(A::AbstractDerivativeOperator{T}, N::Int=A.len) where T
53+
function SparseArrays.sparse(A::DerivativeOperator{T}, N::Int=A.len) where T
5454
SparseMatrixCSC(A,N)
5555
end
5656

@@ -247,3 +247,118 @@ LinearAlgebra.Array(Q::ComposedMultiDimBC, Ns) = Tuple(Array.(Q.BCs, Ns))
247247
SparseArrays.SparseMatrixCSC(Q::ComposedMultiDimBC, Ns...) = Tuple(sparse.(Q.BCs, Ns))
248248
SparseArrays.sparse(Q::ComposedMultiDimBC, Ns) = SparseMatrixCSC(Q, Ns)
249249
BandedMatrices.BandedMatrix(Q::ComposedMultiDimBC, Ns) = Tuple(BandedMatrix.(Q.BCs, Ns))
250+
251+
# HIgher Dimensional Concretizations. The following concretizations return two dimensional arrays
252+
# which operate on flattened vectors. Mshape is the size of the unflattened array on which A is operating on.
253+
254+
function LinearAlgebra.Array(A::DerivativeOperator{T,N}, Mshape) where {T,N}
255+
# Case where A is not differentiating along the first dimension
256+
if N != 1
257+
n = 1
258+
for M_i in Mshape[1:N-1]
259+
n *= M_i
260+
end
261+
B = Kron(Array(A), Eye(n))
262+
if N != length(Mshape)
263+
n = 1
264+
for M_i in Mshape[N+1:end]
265+
n *= M_i
266+
end
267+
B = Kron(Eye(n), B)
268+
end
269+
270+
# Case where A is differentiating along hte first dimension
271+
else
272+
n = 1
273+
for M_i in Mshape[2:end]
274+
n *= M_i
275+
end
276+
B = Kron(Eye(n), Array(A))
277+
end
278+
return Array(B)
279+
end
280+
281+
function SparseArrays.SparseMatrixCSC(A::DerivativeOperator{T,N}, Mshape) where {T,N}
282+
# Case where A is not differentiating along the first dimension
283+
if N != 1
284+
n = 1
285+
for M_i in Mshape[1:N-1]
286+
n *= M_i
287+
end
288+
B = Kron(sparse(A), sparse(I,n,n))
289+
if N != length(Mshape)
290+
n = 1
291+
for M_i in Mshape[N+1:end]
292+
n *= M_i
293+
end
294+
B = Kron(sparse(I,n,n), B)
295+
end
296+
297+
# Case where A is differentiating along hte first dimension
298+
else
299+
n = 1
300+
for M_i in Mshape[2:end]
301+
n *= M_i
302+
end
303+
B = Kron(sparse(I,n,n), sparse(A))
304+
end
305+
return sparse(B)
306+
end
307+
308+
function SparseArrays.sparse(A::DerivativeOperator{T,N}, Mshape) where {T,N}
309+
return SparseMatrixCSC(A,Mshape)
310+
end
311+
312+
function BandedMatrices.BandedMatrix(A::DerivativeOperator{T,N}, Mshape) where {T,N}
313+
# Case where A is not differentiating along the first dimension
314+
if N != 1
315+
n = 1
316+
for M_i in Mshape[1:N-1]
317+
n *= M_i
318+
end
319+
B = Kron(BandedMatrix(A), Eye(n))
320+
if N != length(Mshape)
321+
n = 1
322+
for M_i in Mshape[N+1:end]
323+
n *= M_i
324+
end
325+
B = Kron(Eye(n), B)
326+
end
327+
328+
# Case where A is differentiating along hte first dimension
329+
else
330+
n = 1
331+
for M_i in Mshape[2:end]
332+
n *= M_i
333+
end
334+
B = Kron(BandedMatrix(Eye(n)), BandedMatrix(A))
335+
end
336+
return BandedMatrix(B)
337+
end
338+
339+
function BlockBandedMatrices.BandedBlockBandedMatrix(A::DerivativeOperator{T,N}, Mshape) where {T,N}
340+
# Case where A is not differentiating along the first dimension
341+
if N != 1
342+
n = 1
343+
for M_i in Mshape[1:N-1]
344+
n *= M_i
345+
end
346+
B = Kron(BandedMatrix(A), Eye(n))
347+
if N != length(Mshape)
348+
n = 1
349+
for M_i in Mshape[N+1:end]
350+
n *= M_i
351+
end
352+
B = Kron(Eye(n), B)
353+
end
354+
355+
# Case where A is differentiating along hte first dimension
356+
else
357+
n = 1
358+
for M_i in Mshape[2:end]
359+
n *= M_i
360+
end
361+
B = Kron(BandedMatrix(Eye(n)), BandedMatrix(A))
362+
end
363+
return BandedBlockBandedMatrix(B)
364+
end

test/concretization.jl

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
using SparseArrays, DiffEqOperators, LinearAlgebra, Random,
2+
Test, BandedMatrices, FillArrays, LazyArrays, BlockBandedMatrices
3+
4+
# This test file tests for the correctness of higher dimensional concretization.
5+
# The tests verify that multiplication in the concretized case agrees with the matrix-free
6+
# multiplication
7+
8+
@testset "First Dimension" begin
9+
10+
# Test that even when we have a vector, the concretizations using the higher dimension dispatch still function
11+
# correctly
12+
M = rand(22)
13+
14+
L1 = CenteredDifference(1,2,1.0,20)
15+
L2 = CenteredDifference(1,2,1.0,20)
16+
L3 = CenteredDifference(3,3,1.0,20)
17+
18+
@test L1*M Array(L1, size(M))*vec(M) sparse(L1,size(M))*M BandedMatrix(L1, size(M))*M BandedBlockBandedMatrix(L1,size(M))*M
19+
@test L2*M Array(L2, size(M))*vec(M) sparse(L2,size(M))*M BandedMatrix(L2, size(M))*M BandedBlockBandedMatrix(L2,size(M))*M
20+
@test L3*M Array(L3, size(M))*vec(M) sparse(L3,size(M))*M BandedMatrix(L3, size(M))*M BandedBlockBandedMatrix(L3,size(M))*M
21+
22+
M = rand(22,2,2,2,2)
23+
24+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
25+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
26+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
27+
28+
end
29+
30+
@testset "Second Dimension" begin
31+
32+
M = rand(2,22,2)
33+
34+
L1 = CenteredDifference{2}(1,2,1.0,20)
35+
L2 = CenteredDifference{2}(1,2,1.0,20)
36+
L3 = CenteredDifference{2}(3,3,1.0,20)
37+
38+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
39+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
40+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
41+
42+
M = rand(2,22,2,2)
43+
44+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
45+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
46+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
47+
48+
M = rand(2,22,2,2,3)
49+
50+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
51+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
52+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
53+
54+
end
55+
56+
@testset "Third Dimension" begin
57+
58+
M = rand(3,2,22)
59+
60+
L1 = CenteredDifference{3}(1,2,1.0,20)
61+
L2 = CenteredDifference{3}(1,2,1.0,20)
62+
L3 = CenteredDifference{3}(3,3,1.0,20)
63+
64+
65+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
66+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
67+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
68+
69+
M = rand(3,2,22,2)
70+
71+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
72+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
73+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
74+
75+
M = rand(3,2,22,2,3)
76+
77+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
78+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
79+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
80+
81+
end
82+
83+
@testset "Fifth Dimension" begin
84+
85+
M = rand(3,2,3,2,22)
86+
87+
L1 = CenteredDifference{5}(1,2,1.0,20)
88+
L2 = CenteredDifference{5}(1,2,1.0,20)
89+
L3 = CenteredDifference{5}(3,3,1.0,20)
90+
91+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
92+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
93+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
94+
95+
M = rand(3,2,3,2,22,3)
96+
97+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
98+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
99+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
100+
101+
M = rand(2,2,2,2,22,3,2)
102+
103+
@test vec(L1*M) Array(L1, size(M))*vec(M) sparse(L1,size(M))*vec(M) BandedMatrix(L1, size(M))*vec(M) BandedBlockBandedMatrix(L1,size(M))*vec(M)
104+
@test vec(L2*M) Array(L2, size(M))*vec(M) sparse(L2,size(M))*vec(M) BandedMatrix(L2, size(M))*vec(M) BandedBlockBandedMatrix(L2,size(M))*vec(M)
105+
@test vec(L3*M) Array(L3, size(M))*vec(M) sparse(L3,size(M))*vec(M) BandedMatrix(L3, size(M))*vec(M) BandedBlockBandedMatrix(L3,size(M))*vec(M)
106+
107+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ import Base: isapprox
1616
@time @safetestset "Convolutions" begin include("convolutions.jl") end
1717
@time @safetestset "Differentiation Dimension" begin include("differentiation_dimension.jl") end
1818
@time @safetestset "2D and 3D fast multiplication" begin include("2D_3D_fast_multiplication.jl") end
19+
@time @safetestset "Higher Dimensional Concretization" begin include("concretization.jl") end

0 commit comments

Comments
 (0)