Skip to content

Commit 7afac3e

Browse files
committed
Finished first pass at generate_bc_rules
1 parent 8e93a88 commit 7afac3e

File tree

7 files changed

+103
-66
lines changed

7 files changed

+103
-66
lines changed

src/MOL_utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,5 @@ function unitindices(N::Int) #create unit CartesianIndex for each dimension
7474
end
7575

7676
half_range(x) = -div(x,2):div(x,2)
77+
78+

src/MethodOfLines.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ module MethodOfLines
1111

1212
include("MOL_utils.jl")
1313
include("discretization/fornberg.jl")
14+
include("bcs/generate_bc_eqs.jl")
1415
include("discretization/discretize_vars.jl")
1516
include("discretization/differential_discretizer.jl")
1617
include("discretization/generate_rules.jl")
17-
include("bcs/generate_bc_eqs.jl")
1818

1919
include("discretization/MOL_discretization.jl")
2020

src/bcs/generate_bc_eqs.jl

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
abstract type AbstractBoundary end
44

5-
struct LowerBoundary <: AbstractBoundary
5+
abstract type AbstractTruncatingBoundary <: AbstractBoundary end
6+
7+
abstract type AbstractExtendingBoundary <: AbstractBoundary end
8+
9+
struct LowerBoundary <: AbstractTruncatingBoundary
610
end
711

8-
struct UpperBoundary<: AbstractBoundary
12+
struct UpperBoundary<: AbstractTruncatingBoundary
913
end
1014

11-
struct CompleteBoundary <: AbstractBoundary
15+
struct CompleteBoundary <: AbstractTruncatingBoundary
1216
end#
1317

1418
struct PeriodicBoundary <: AbstractBoundary
@@ -20,7 +24,7 @@ end
2024

2125
"""
2226
Mutates bceqs and u0 by finding relevant equations and discretizing them.
23-
TODO: return a handler for use with generate_finite_difference_rules
27+
TODO: return a handler for use with generate_finite_difference_rules and pull out initial condition
2428
"""
2529
function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
2630

@@ -63,7 +67,7 @@ function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
6367
for term in terms, r in lower_boundary_rules
6468
if r(term) !== nothing
6569
u_, x_ = (term, r(term))
66-
boundary = :lower
70+
boundary = LowerBoundary()
6771
for term_ in setdiff(terms, term)
6872
for r in upper_boundary_rules
6973
if r(term_) !== nothing
@@ -77,7 +81,7 @@ function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
7781
for term in terms, r in upper_boundary_rules
7882
if r(term) !== nothing
7983
u_, x_ = (term, r(term))
80-
boundary = :upper
84+
boundary = UpperBoundary()
8185
break
8286
end
8387
end
@@ -86,7 +90,7 @@ function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
8690

8791
push!(bceqs, vec(map(s.Iedge[x_][boundary]) do II
8892
rules = generate_bc_rules(II, s, bc, u_, boundary, derivweights)
89-
rules = vcat(rules, generate_finite_difference_rules(II, s, bc, derivweights))
93+
rules = vcat(rules, )
9094

9195
substitute(bc.lhs, rules) ~ substitute(bc.rhs, rules)
9296
end))
@@ -96,17 +100,45 @@ function BoundaryHandler!!(u0, bceqs, bcs, s, depvar_ops, tspan, derivweights)
96100
end
97101
end
98102

99-
function get_active_variable(bc, s, depvar_ops)
100-
bcdepvar = first(get_depvars(bc.lhs, depvar_ops))
101-
out = Array{typeof(operation(s.vars[1]))}()
103+
function generate_bc_rules(II, dim, s, bc, u_, x_, boundary::AbstractTruncatingBoundary, derivweights, G::CenterAlignedGrid)
104+
# depvarbcmaps will dictate what to replace the variable terms with in the bcs
105+
# replace u(t,0) with u₁, etc
106+
ufunc(v, I, x) = s.discvars[v][I]
107+
102108
for u in s.vars
103-
if u isa Sym && isequal(operation(u), operation(bcdepvar))
104-
push!(out, u)
109+
if isequal(operation(u), operation(u_))
110+
# What to replace derivatives at the boundary with
111+
depvarderivbcmaps = [(Differential(x)^d)(u_) => central_difference(derivweights.map[Differential(x_)^d], II, s, (s.x2i[x_],x_), u, ufunc) for d in derivweights.orders[x_]]
112+
# ? Does this need to be done for all variables at the boundary?
113+
depvarbcmaps = [u_ => s.discvars[u][II]]
105114
end
106115
end
107-
return out
108-
end
109116

117+
fd_rules = generate_finite_difference_rules(II, s, bc, derivweights)
118+
varrules = axiesvals(s, II)
119+
120+
return vcat(depvarderivbcmaps, depvarbcmaps, fd_rules, varrules)
121+
end
110122

123+
function generate_bc_rules(II, dim, s, bc, u_, x_, boundary::AbstractTruncatingBoundary, derivweights, G::EdgeAlignedGrid)
124+
125+
boundaryoffset(::LowerBoundary) = 1/2
126+
boundaryoffset(::UpperBoundary) = -1/2
127+
ufunc(v, I, x) = s.discvars[v][I]
111128

112-
#function generate_u0_and_bceqs_with_rules!!(u0, bceqs, bcs, t, s, depvar_ops)
129+
# depvarbcmaps will dictate what to replace the variable terms with in the bcs
130+
# replace u(t,0) with u₁, etc
131+
for u in s.vars
132+
if isequal(operation(u), operation(u_))
133+
depvarderivbcmaps = [(Differential(x)^d)(u_) => half_offset_centered_difference(derivweights.halfoffsetmap[Differential(x_)^d], II, s, offset(boundary), (s.x2i[x_],x_), u, ufunc) for d in derivweights.orders[x_]]
134+
135+
depvarbcmaps = [u_ => half_offset_centered_difference(derivweights.interpmap[x_], II, s, offset(boundary), (s.x2i[x_],x_), u, ufunc)]
136+
end
137+
end
138+
139+
fd_rules = generate_finite_difference_rules(II, s, bc, derivweights)
140+
varrules = axiesvals(s, II)
141+
valr = valrules(s, II)
142+
143+
return vcat(depvarderivbcmaps, depvarbcmaps, fd_rules, varrules)
144+
end

src/discretization/MOL_discretization.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
# Method of lines discretization scheme
2+
abstract type AbstractGrid end
23

3-
@enum GridAlign center_align edge_align
4+
struct CenterAlignedGrid <: AbstractGrid
5+
end
6+
7+
struct EdgeAlignedGrid <: AbstractGrid
8+
end
9+
10+
const center_align=CenterAlignedGrid()
11+
const edge_align=EdgeAlignedGrid()
412

5-
struct MOLFiniteDifference{T,T2} <: DiffEqBase.AbstractDiscretization
6-
dxs::T
7-
time::T2
13+
struct MOLFiniteDifference{G} <: DiffEqBase.AbstractDiscretization
14+
dxs
15+
time
816
approx_order::Int
9-
grid_align::GridAlign
17+
grid_align::G
1018
end
1119

1220
# Constructors. If no order is specified, both upwind and centered differences will be 2nd order
13-
function MOLFiniteDifference(dxs, time=nothing; upwind_order = 1, centered_order = 2, grid_align=center_align)
21+
function MOLFiniteDifference(dxs, time=nothing; upwind_order = 1, centered_order = 2, grid_align=CenterAlignedGrid())
1422

1523
if centered_order % 2 != 0
1624
warn("Discretization centered_order must be even, rounding up to $(centered_order+1)")
1725
end
18-
return MOLFiniteDifference(dxs, time, centered_order, grid_align)
26+
return MOLFiniteDifference{typeof(grid_align)}(dxs, time, centered_order, grid_align)
1927
end
2028

2129
function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::MethodOfLines.MOLFiniteDifference)
@@ -81,7 +89,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem, discretization::Method
8189

8290
# Discretize the equation on the interior
8391
pdeeqs = vec(map(interior) do II
84-
rules = generate_finite_difference_rules(II, s, pde, derivweights)
92+
rules = vcat(generate_finite_difference_rules(II, s, pde, derivweights), valrules(s, II))
8593
substitute(pde.lhs,rules) ~ substitute(pde.rhs,rules)
8694
end)
8795

src/discretization/differential_discretizer.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ end
99

1010
function DifferentialDiscretizer(pde, s, discretization)
1111
approx_order = discretization.approx_order
12+
# TODO: Include bcs in this calculation
1213
d_orders(x) = reverse(sort(collect(union(differential_order(pde.rhs, x), differential_order(pde.lhs, x)))))
1314

1415
# central_deriv_rules = [(Differential(s)^2)(u) => central_deriv(2,II,j,k) for (j,s) in enumerate(s.nottime), (k,u) in enumerate(s.vars)]
@@ -23,16 +24,18 @@ function DifferentialDiscretizer(pde, s, discretization)
2324
# TODO: Only generate weights for derivatives that are actually used and avoid redundant calculations
2425
rs = [(Differential(x)^d) => CompleteCenteredDifference(d, approx_order, s.dxs[x] ) for d in last(orders).second]
2526

27+
nonlinlap = vcat(nonlinlap, [Differential(x)^d => CompleteHalfCenteredDifference(d, approx_order, s.dxs[x]) for d in last(orders).second])
2628
differentialmap = vcat(differentialmap, rs)
27-
push!(nonlinlap, x => CompleteHalfCenteredDifference(0, approx_order, s.dxs[x])
28-
push!(interp, x => CompleteHalfCenteredDifference(1, approx_order, s.dxs[x]))
29+
push!(interp, x => CompleteHalfCenteredDifference(0, approx_order, s.dxs[x])
2930
end
3031

3132
return DifferentialDiscretizer{eltype(orders), typeof(Dict(differentialmap))}(approx_order, Dict(differentialmap), Dict(nonlinlap), Dict(interp), Dict(orders))
3233
end
3334

34-
35-
# ufunc is a function that returns the correct discretization indexed at Itap, it is designed this way to allow for central differences of arbitrary expressions which may be needed in some schemes
35+
"""
36+
Performs a centered difference in `x` centered at index `II` of `u`
37+
ufunc is a function that returns the correct discretization indexed at Itap, it is designed this way to allow for central differences of arbitrary expressions which may be needed in some schemes
38+
"""
3639
function central_difference(D, II, s, jx, u, ufunc)
3740
j, x = jx
3841
# unit index in direction of the derivative
@@ -88,10 +91,9 @@ TODO: consider refactoring this to harmonize with centered difference
8891
function get_half_offset_weights_and_stencil(D, II, s, offset, jx)
8992
j, x = jx
9093
I1 = unitindices(nparams(s))[j]
91-
9294
# Shift the current index to the correct offset
93-
II_prime = II + offset*I1
94-
95+
II_prime = II + Int(offset-0.5)*I1
96+
@assert all(i-> 0<i<=length(s,x), II_prime) "Index out of bounds"
9597
return _get_weights_and_stencil(D, II_prime, I1, s, offset, j, x)
9698
end
9799

src/discretization/discretize_vars.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Return a map of of variables to the gridpoints at the edge of the domain
44
"""
55
@inline function get_edgevals(params, axies, i)
6-
return [params[i] => first(axies[i]), params[i] => last(axies[i])]
6+
return [params[i] => first(axies[params[i]]), params[i] => last(axies[params[i]])]
77
end
88

99
"""
@@ -42,17 +42,28 @@ params(s::DiscreteSpace{N,M}) where {N,M}= s.params
4242
grid_idxs(s::DiscreteSpace) = CartesianIndices(((axes(g)[1] for g in s.grid)...,))
4343
edge_idxs(s::DiscreteSpace{N}) where {N} = reduce(vcat, [[vcat([Colon() for j = 1:i-1], 1, [Colon() for j = i+1:N]), vcat([Colon() for j = 1:i-1], length(s.axies[i]), [Colon() for j = i+1:N])] for i = 1:N])
4444

45-
axiesvals(s::DiscreteSpace{N}) where {N} = map(y -> [s.nottime[i] => s.axies[s.nottime[i]][y.I[i]] for i = 1:N], s.Iaxies)
46-
gridvals(s::DiscreteSpace{N}) where {N} = map(y -> [s.nottime[i] => s.grid[s.nottime[i]][y.I[i]] for i = 1:N], s.Igrid)
45+
axiesvals(s::DiscreteSpace{N}, I) where {N} = [x => s.axies[x][I[j]] for (j,x) in enumerate(s.nottime)]
46+
gridvals(s::DiscreteSpace{N}, I) where {N} = [x => s.grid[x][I[j]] for (j,x) in enumerate(s.nottime)]
4747

4848
## Boundary methods ##
4949
edgevals(s::DiscreteSpace{N}) where {N} = reduce(vcat, [get_edgevals(s.nottime, s.axies, i) for i = 1:N])
50-
edgevars(s::DiscreteSpace) = [[d[e...] for e in s.Iedge] for d in s.discvars]
50+
edgevars(s::DiscreteSpace, I) = [u => s.discvars[u][I] for u in s.vars]
5151

52-
@inline function edgemaps(s::DiscreteSpace)
53-
bclocs(s::DiscreteSpace) = map(e -> substitute.(s.params, e), edgevals(s))
54-
return Dict(bclocs(s) .=> [axiesvals(s)[e...] for e in s.Iedge])
52+
"""
53+
Generate a map of variables to the gridpoints at the edge of the domain
54+
"""
55+
@inline function edgemaps(s::DiscreteSpace, ::LowerBoundary)
56+
return [x => first(s.axies[x]) for x in s.nottime]
5557
end
58+
@inline function edgemaps(s::DiscreteSpace, ::LowerBoundary)
59+
return [x => last(s.axies[x]) for x in s.nottime]
60+
end
61+
62+
varmaps(s::DiscreteSpace, II) = [u => s.discvars[u][II] for u in s.vars]
63+
64+
valmaps(s::DiscreteSpace, II) = vcat(varmaps(s,II), gridvals(s,II))
65+
66+
5667

5768
Iinterior(s::DiscreteSpace) = s.Igrid[[2:(length(s, x)-1) for x in s.nottime]...]
5869

@@ -109,7 +120,7 @@ function DiscreteSpace(domain, depvars, indvars, nottime, discretization)
109120
end
110121

111122
# Build symbolic maps for boundaries
112-
Iedge = vcat((vcat(vec(selectdim(Igrid, dim, 1)), vec(selectdim(Igrid, dim, length(grid[dim].second)))) for dim in 1:nspace)...)
123+
Iedge = Dict([x => Dict([LowerBoundary() => vec(selectdim(Igrid, dim, 1)), UpperBoundary() => vec(selectdim(Igrid, dim, length(grid[dim].second)))]) for (dim, x) in enumerate(s.nottime)])
113124

114125
nottime2dim = [nottime[i] => i for i in 1:nspace]
115126
dim2nottime = [i => nottime[i] for i in 1:nspace]

src/discretization/generate_rules.jl renamed to src/discretization/generate_finite_difference_rules.jl

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ Interpolate gridpoints by taking the average of the values of the discrete point
55
"""
66
function interpolate_discrete_param(II, s, itap, j, x)
77
# * This will need to be updated to dispatch on grid type when grids become more general
8-
offset = itap+1/2
98
if (II[j]+itap) < 1
109
return s.grid[x][1]+s.dxs[x]*offset
1110
elseif (II[j]+itap) > (length(x) - 1)
@@ -47,13 +46,14 @@ function cartesian_nonlinear_laplacian(expr, II, derivweights, s, x, u)
4746
# See scheme 1, namely the term without the 1/r dependence. See also #354 and #371 in DiffEqOperators, the previous home of this package.
4847

4948
jx = j, x = (s.x2i(x), x)
50-
D_inner = derivweights.halfoffsetmap[x]
49+
D_inner = derivweights.halfoffsetmap[Differential(x)]
5150
inner_interpolater = derivweights.interpmap[x]
5251

53-
# Get the outer weights and stencil to generate the required
54-
outerweights, outerstencil = get_half_offset_weights_and_stencil(D_inner, II, s, 0, jx)
52+
# Get the outer weights and stencil.
53+
# ! The offset should eliminate the need for bounds checking, the inner function ensuring that the taps lie inbounds.
54+
outerweights, outerstencil = _get_weights_and_stencil(D_inner, II, s, 1/2, jx)
5555
# Index offsets of each stencil in the inner finite difference to get the correct stencil for each needed half grid point, 0 corresopnds to x+1/2
56-
itaps = getindex.(outerstencil, (j,))
56+
itaps = getindex.(outerstencil .- II, (j,)) .+ 0.5
5757

5858
# Get the correct weights and stencils for this II
5959
inner_deriv_weights_and_stencil = [get_half_offset_weights_and_stencil(D_inner, II, s, itap, jx) for itap in itaps]
@@ -128,6 +128,7 @@ function spherical_diffusion(innerexpr, II, derivweights, s, r, u)
128128
return exprhere*(D_1_u/substitute(r, _rsubs(r, II)) + cartesian_nonlinear_laplacian(innerexpr, II, derivweights, s, r, u))
129129
end
130130

131+
131132
"""
132133
`generate_finite_difference_rules`
133134
@@ -150,8 +151,6 @@ There are of course more specific schemes that are used to improve stability/spe
150151
Please submit an issue if you know of any special cases that are not implemented, with links to papers and/or code that demonstrates the special case.
151152
"""
152153
function generate_finite_difference_rules(II, s, pde, derivweights)
153-
valrules = vcat([u => s.discvars[u][II] for u in s.vars],
154-
[x => s.grid[x][II[j]] for (j,x) in enumerate(s.nottime)])
155154
# central_deriv_rules = [(Differential(s)^2)(u) => central_deriv(2,II,j,k) for (j,s) in enumerate(s.nottime), (k,u) in enumerate(s.vars)]
156155

157156
central_ufunc(u, I, x) = s.discvars[u][I]
@@ -177,9 +176,9 @@ function generate_finite_difference_rules(II, s, pde, derivweights)
177176
cartesian_deriv_rules = vcat(vec(cartesian_deriv_rules),vec(
178177
[@rule ($(Differential(x))($(Differential(x))(u)/~a)) => cartesian_nonlinear_laplacian(1/~a, II, derivweights, s, x, u) for x in s.nottime, u in s.vars]))
179178

180-
spherical_deriv_rules = [@rule *(~~a, (r^-2), ($(Differential(r))(*(~~c, (r^2), ~~d, $(Differential(r))(u), ~~e))), ~~b) =>
179+
spherical_deriv_rules = vec([@rule *(~~a, (r^-2), ($(Differential(r))(*(~~c, (r^2), ~~d, $(Differential(r))(u), ~~e))), ~~b) =>
181180
*(~a..., spherical_diffusion(*(~c..., ~d..., ~e...), II, derivweights, s, r, u), ~b...)
182-
for r in s.nottime, u in s.vars]
181+
for r in s.nottime, u in s.vars])
183182

184183
rhs_arg = istree(pde.rhs) && (SymbolicUtils.operation(pde.rhs) == +) ? SymbolicUtils.arguments(pde.rhs) : [pde.rhs]
185184
lhs_arg = istree(pde.lhs) && (SymbolicUtils.operation(pde.lhs) == +) ? SymbolicUtils.arguments(pde.lhs) : [pde.lhs]
@@ -200,24 +199,7 @@ function generate_finite_difference_rules(II, s, pde, derivweights)
200199
end
201200
end
202201
rules = vcat(vec(nonlinlap_rules),
203-
vec(central_deriv_rules_cartesian),
204-
valrules)
202+
vec(central_deriv_rules_cartesian))
205203
return rules
206204
end
207205

208-
function generate_bc_rules(II, dim, s, bc, edgemaps)
209-
# ! Recognise which dim a BC is on, and use that to get the axiesvals at the boundary
210-
# ! loop through the Iedge at that boundary
211-
# ! replace the symbolic variables at either end of the boundary with the appropriate discvars, interpolated if nessecary
212-
# ! eventually move to multi dim interpolations to improve validity
213-
214-
# depvarbcmaps will dictate what to replace the variable terms with in the bcs
215-
# replace u(t,0) with u₁, etc
216-
if grid_align == center_align
217-
depvarbcmaps = reduce(vcat,[substitute(depvar, edgevals(s)) .=> edgevar for (depvar, edgevar) in zip(s.vars, edgevars(s, II))])
218-
elseif grid_align == edge_align
219-
220-
end
221-
222-
varrules = edgemaps(s)
223-
rules = vcat(depvarbcmaps, edgemaps)

0 commit comments

Comments
 (0)