Skip to content

Commit 8c8226e

Browse files
committed
test fixes
1 parent 791866a commit 8c8226e

18 files changed

+354
-185
lines changed

src/MOLFiniteDifference.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
struct MOLFiniteDifference{G} <: DiffEqBase.AbstractDiscretization
2+
dxs
3+
time
4+
approx_order::Int
5+
grid_align::G
6+
end
7+
8+
# Constructors. If no order is specified, both upwind and centered differences will be 2nd order
9+
function MOLFiniteDifference(dxs, time=nothing; approx_order = 2, grid_align=CenterAlignedGrid())
10+
11+
if approx_order % 2 != 0
12+
warn("Discretization approx_order must be even, rounding up to $(approx_order+1)")
13+
end
14+
return MOLFiniteDifference{typeof(grid_align)}(dxs, time, approx_order, grid_align)
15+
end

src/MOL_utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ function unitindices(N::Int) #create unit CartesianIndex for each dimension
7373
Tuple(out)
7474
end
7575

76+
function split_additive_terms(eq)
77+
rhs_arg = istree(eq.rhs) && (SymbolicUtils.operation(eq.rhs) == +) ? SymbolicUtils.arguments(eq.rhs) : [eq.rhs]
78+
lhs_arg = istree(eq.lhs) && (SymbolicUtils.operation(eq.lhs) == +) ? SymbolicUtils.arguments(eq.lhs) : [eq.lhs]
79+
80+
return vcat(lhs_arg,rhs_arg)
81+
end
82+
7683
half_range(x) = -div(x,2):div(x,2)
7784

7885

src/MethodOfLines.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,17 @@ module MethodOfLines
1111

1212
include("MOL_utils.jl")
1313
include("discretization/fornberg.jl")
14-
include("bcs/generate_bc_eqs.jl")
14+
15+
include("grid_types.jl")
16+
include("MOLFiniteDifference.jl")
17+
include("bcs/boundary_types.jl")
18+
1519
include("discretization/discretize_vars.jl")
20+
1621
include("discretization/differential_discretizer.jl")
17-
include("discretization/generate_rules.jl")
22+
include("discretization/generate_finite_difference_rules.jl")
23+
24+
include("bcs/generate_bc_eqs.jl")
1825

1926
include("discretization/MOL_discretization.jl")
2027

src/bcs/boundary_types.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
### INITIAL AND BOUNDARY CONDITIONS ###
2+
3+
abstract type AbstractBoundary end
4+
5+
abstract type AbstractTruncatingBoundary <: AbstractBoundary end
6+
7+
abstract type AbstractExtendingBoundary <: AbstractBoundary end
8+
9+
struct LowerBoundary <: AbstractTruncatingBoundary
10+
end
11+
12+
struct UpperBoundary<: AbstractTruncatingBoundary
13+
end
14+
15+
struct CompleteBoundary <: AbstractTruncatingBoundary
16+
end
17+
18+
struct PeriodicBoundary <: AbstractBoundary
19+
end
20+
21+
struct BoundaryHandler{hasperiodic}
22+
boundaries::Dict{Num, AbstractBoundary}
23+
end

src/bcs/generate_bc_eqs.jl

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,36 @@
1-
### INITIAL AND BOUNDARY CONDITIONS ###
1+
function _boundary_rules(s, orders, val)
2+
args = copy(s.params)
23

3-
abstract type AbstractBoundary end
4-
5-
abstract type AbstractTruncatingBoundary <: AbstractBoundary end
6-
7-
abstract type AbstractExtendingBoundary <: AbstractBoundary end
4+
if isequal(val, floor(val))
5+
args = [substitute.(args, (x=>val,)), substitute.(args, (x=>Int(val),))]
6+
else
7+
args = [substitute.(args, (x=>val,))]
8+
end
9+
substitute.(args, (x=>lowerboundary(x),))
10+
11+
rules = [@rule operation(u)(arg...) => (u, x) for u in s.vars, arg in args]
812

9-
struct LowerBoundary <: AbstractTruncatingBoundary
13+
return vcat(rules, vec([@rule (Differential(x)^d)(operation(u)(arg...)) => (u, x) for d in orders[x], u in s.vars, arg in args]))
1014
end
1115

12-
struct UpperBoundary<: AbstractTruncatingBoundary
13-
end
16+
function generate_boundary_matching_rules(s, orders)
17+
# TODO: Check for bc equations of multiple variables
18+
lowerboundary(x) = first(s.axies[x])
19+
upperboundary(x) = last(s.axies[x])
1420

15-
struct CompleteBoundary <: AbstractTruncatingBoundary
16-
end
21+
# Rules to match boundary conditions on the lower boundaries
22+
lower = reduce(vcat, [_boundary_rules(s, orders, lowerboundary(x)) for x in s.vars])
1723

18-
struct PeriodicBoundary <: AbstractBoundary
19-
end
24+
upper = reduce(vcat, [_boundary_rules(s, orders, upperboundary(x)) for x in s.vars])
2025

21-
struct BoundaryHandler{hasperiodic}
22-
boundaries::Dict{Num, AbstractBoundary}
26+
return (lower, upper)
2327
end
2428

2529
"""
2630
Mutates bceqs and u0 by finding relevant equations and discretizing them.
27-
TODO: return a handler for use with generate_finite_difference_rules and pull out initial condition
31+
TODO: return a handler for use with generate_finite_difference_rules and pull out initial condition. Important to remember that BCs can have
2832
"""
29-
function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
33+
function BoundaryHandler!!(u0, bceqs, bcs, s::DiscreteSpace, depvar_ops, tspan, derivweights::DifferentialDiscretizer)
3034

3135
t=s.time
3236

@@ -38,30 +42,33 @@ function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
3842

3943
# Create some rules to match which bundary/variable a bc concerns
4044
# * Assume that the term of the condition is applied additively and has no multiplier/divisor/power etc.
41-
# ? Is it nessecary to check whether all other args are present?
42-
lower_boundary_rules = vec([@rule operation(u)(~~a, lowerboundary(x), ~~b) => IfElse.ifelse(all(y-> y in vcat(~~a, ~~b), setdiff(x, arguments(u))), x, nothing) for x in setdiff(arguments(u), t), u in s.vars])
45+
46+
## BC matching rules, returns the variable and parameter the bc concerns
47+
48+
lower_boundary_rules, upper_boundary_rules = generate_boundary_matching_rules(s, derivweights.orders)
4349

44-
upper_boundary_rules = vec([@rule operation(u)(~~a, upperboundary(x), ~~b) => IfElse.ifelse(all(y-> y in vcat(~~a, ~~b), setdiff(x, arguments(u))), x, nothing) for x in setdiff(arguments(u), t), u in s.vars])
50+
# indexes for Iedge depending on boundary type
51+
idx(::LowerBoundary) = 1
52+
idx(::UpperBoundary) = 2
4553

4654
# Generate initial conditions and bc equations
4755
for bc in bcs
56+
# * Assume in the form `u(...) ~ ...` for now
4857
bcdepvar = first(get_depvars(bc.lhs, depvar_ops))
58+
4959
if any(u -> isequal(operation(u), operation(bcdepvar)), s.vars)
5060
if t !== nothing && operation(bc.lhs) isa Sym && !any(x -> isequal(x, t.val), arguments(bc.lhs))
5161
# initial condition
52-
# * Assume in the form `u(...) ~ ...` for now
5362
# * Assume that the initial condition is not in terms of the initial derivative
5463
initindex = findfirst(isequal(bc.lhs), initmaps)
5564
if initindex !== nothing
5665
push!(u0,vec(s.discvars[s.vars[initindex]] .=> substitute.((bc.rhs,),gridvals(s))))
5766
end
5867
else
5968
# Split out additive terms
60-
rhs_arg = istree(pde.rhs) && (SymbolicUtils.operation(pde.rhs) == +) ? SymbolicUtils.arguments(pde.rhs) : [pde.rhs]
61-
lhs_arg = istree(pde.lhs) && (SymbolicUtils.operation(pde.lhs) == +) ? SymbolicUtils.arguments(pde.lhs) : [pde.lhs]
69+
terms = split_additive_terms(bc)
6270

6371
u_, x_ = (nothing, nothing)
64-
terms = vcat(lhs_arg,rhs_arg)
6572
boundary = nothing
6673
# Check whether the bc is on the lower boundary, or periodic
6774
for term in terms, r in lower_boundary_rules
@@ -75,6 +82,7 @@ function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
7582
#TODO: Add handling for perioodic boundary conditions here
7683
end
7784
end
85+
end
7886
break
7987
end
8088
end
@@ -86,24 +94,26 @@ function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
8694
end
8795
end
8896

89-
@assert boundary !== nothing "Boundary condition ${bc} is not on a boundary of the domain, or is not a valid boundary condition"
97+
@assert boundary !== nothing "Boundary condition $bc is not on a boundary of the domain, or is not a valid boundary condition"
9098

91-
push!(bceqs, vec(map(s.Iedge[x_][boundary]) do II
92-
rules = generate_bc_rules(II, s, bc, u_, boundary, derivweights)
93-
rules = vcat(rules, )
99+
push!(bceqs, vec(map(s.Iedge[x_][idx(boundary)]) do II
100+
rules = generate_bc_rules(II, derivweights, s, bc, u_, x_, boundary)
94101

95102
substitute(bc.lhs, rules) ~ substitute(bc.rhs, rules)
96103
end))
97104
end
98-
else
99-
throw(ArgumentError("No active variables in boundary condition $bc lhs, please ensure that bcs are "))
105+
end
100106
end
101107
end
102108

103-
function generate_bc_rules(II, dim, s, bc, u_, x_, boundary::AbstractTruncatingBoundary, derivweights, G::CenterAlignedGrid)
109+
function generate_bc_rules(II, derivweights, s::DiscreteSpace{N,M,G}, bc, u_, x_, ::AbstractTruncatingBoundary) where {N, M, G<:CenterAlignedGrid}
104110
# depvarbcmaps will dictate what to replace the variable terms with in the bcs
105111
# replace u(t,0) with u₁, etc
106112
ufunc(v, I, x) = s.discvars[v][I]
113+
114+
depvarderivbcmaps = []
115+
depvarbcmaps = []
116+
107117
# * Assume that the BC is in terms of an explicit expression, not containing references to variables other than u_ at the boundary
108118
for u in s.vars
109119
if isequal(operation(u), operation(u_))
@@ -124,12 +134,15 @@ function generate_bc_rules(II, dim, s, bc, u_, x_, boundary::AbstractTruncatingB
124134
return vcat(depvarderivbcmaps, depvarbcmaps, fd_rules, varrules)
125135
end
126136

127-
function generate_bc_rules(II, dim, s, bc, u_, x_, boundary::AbstractTruncatingBoundary, derivweights, G::EdgeAlignedGrid)
137+
function generate_bc_rules(II, derivweights, s::DiscreteSpace{N,M,G}, bc, u_, x_, boundary::AbstractTruncatingBoundary) where {N, M, G<:EdgeAlignedGrid}
128138

129139
offset(::LowerBoundary) = 1/2
130140
offset(::UpperBoundary) = -1/2
131141
ufunc(v, I, x) = s.discvars[v][I]
132142

143+
depvarderivbcmaps = []
144+
depvarbcmaps = []
145+
133146
# depvarbcmaps will dictate what to replace the variable terms with in the bcs
134147
# replace u(t,0) with u₁, etc
135148
# * Assume that the BC is in terms of an explicit expression, not containing references to variables other than u_ at the boundary
@@ -144,7 +157,7 @@ function generate_bc_rules(II, dim, s, bc, u_, x_, boundary::AbstractTruncatingB
144157

145158
fd_rules = generate_finite_difference_rules(II, s, bc, derivweights)
146159
varrules = axiesvals(s, x_, II)
147-
160+
148161
# valrules should be caught by depvarbcmaps and varrules if the above assumption holds
149162
#valr = valrules(s, II)
150163

src/discretization/MOL_discretization.jl

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,6 @@
11
# Method of lines discretization scheme
2-
abstract type AbstractGrid end
32

4-
struct CenterAlignedGrid <: AbstractGrid
5-
end
6-
7-
struct EdgeAlignedGrid <: AbstractGrid
8-
end
9-
10-
const center_align=CenterAlignedGrid()
11-
const edge_align=EdgeAlignedGrid()
12-
13-
struct MOLFiniteDifference{G} <: DiffEqBase.AbstractDiscretization
14-
dxs
15-
time
16-
approx_order::Int
17-
grid_align::G
18-
end
19-
20-
# Constructors. If no order is specified, both upwind and centered differences will be 2nd order
21-
function MOLFiniteDifference(dxs, time=nothing; upwind_order = 1, centered_order = 2, grid_align=CenterAlignedGrid())
22-
23-
if centered_order % 2 != 0
24-
warn("Discretization centered_order must be even, rounding up to $(centered_order+1)")
25-
end
26-
return MOLFiniteDifference{typeof(grid_align)}(dxs, time, centered_order, grid_align)
27-
end
28-
29-
function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::MethodOfLines.MOLFiniteDifference)
3+
function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::MethodOfLines.MOLFiniteDifference{G}) where G
304
pdeeqs = pdesys.eqs isa Vector ? pdesys.eqs : [pdesys.eqs]
315
bcs = pdesys.bcs
326
domain = pdesys.domain
@@ -56,6 +30,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::Method
5630
depvars_lhs = get_depvars(pde.lhs, depvar_ops)
5731
depvars_rhs = get_depvars(pde.rhs, depvar_ops)
5832
depvars = collect(depvars_lhs depvars_rhs)
33+
5934
# Read the independent variables,
6035
# ignore if the only argument is [t]
6136
allindvars = Set(filter(xs->!isequal(xs, [t]), map(arguments, depvars)))
@@ -82,7 +57,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::Method
8257
derivweights = DifferentialDiscretizer(pde, s, discretization)
8358

8459
# Get the boundary conditions
85-
generate_u0_and_bceqs!!(u0, bceqs, pdesys.bcs, s, depvar_ops, tspan, derivweights)
60+
BoundaryHandler!!(u0, bceqs, pdesys.bcs, s, depvar_ops, tspan, derivweights)
8661

8762
# Find the indexes on the interior
8863
interior = Iinterior(s)

src/discretization/differential_discretizer.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
struct DifferentialDiscretizer{T, D1}
33
approx_order::Int
44
map::D1
5-
halfoffsetmap::Dict{Num, DiffEqOperators.DerivativeOperator}
5+
halfoffsetmap::Dict
66
interpmap::Dict{Num, DiffEqOperators.DerivativeOperator}
77
orders::Dict{Num, Vector{Int}}
88
end
@@ -26,7 +26,8 @@ function DifferentialDiscretizer(pde, s, discretization)
2626

2727
nonlinlap = vcat(nonlinlap, [Differential(x)^d => CompleteHalfCenteredDifference(d, approx_order, s.dxs[x]) for d in last(orders).second])
2828
differentialmap = vcat(differentialmap, rs)
29-
push!(interp, x => CompleteHalfCenteredDifference(0, approx_order, s.dxs[x])
29+
# A 0th order derivative off the grid is an interpolation
30+
push!(interp, x => CompleteHalfCenteredDifference(0, approx_order, s.dxs[x]))
3031
end
3132

3233
return DifferentialDiscretizer{eltype(orders), typeof(Dict(differentialmap))}(approx_order, Dict(differentialmap), Dict(nonlinlap), Dict(interp), Dict(orders))
@@ -98,12 +99,12 @@ function get_half_offset_weights_and_stencil(D, II, s, offset, jx)
9899
end
99100

100101
# i is the index of the offset, assuming that there is one precalculated set of weights for each offset required for a first order finite difference
101-
function half_offset_centered_difference(D, II, s, offset, i, jx, u, ufunc)
102+
function half_offset_centered_difference(D, II, s, offset, jx, u, ufunc)
102103
j, x = jx
103104
I1 = unitindices(nparams(s))[j]
104105
# Shift the current index to the correct offset
105-
II_prime = II + offset*I1
106+
II_prime = II + Int(offset-0.5)*I1
106107
# Get the weights and stencil
107-
(weights, Itap) = _get_weights_and_stencil(D, II_prime, I1, s, i, j, x)
108+
(weights, Itap) = _get_weights_and_stencil(D, II_prime, I1, s, offset, j, x)
108109
return dot(weights, ufunc(u, Itap, x))
109110
end

0 commit comments

Comments
 (0)