Skip to content

Commit 27e1ec4

Browse files
fix: make SymbolicAffect and AffectSystem type-stable
1 parent a6ba3cd commit 27e1ec4

File tree

1 file changed

+39
-28
lines changed

1 file changed

+39
-28
lines changed

src/systems/callbacks.jl

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ end
77
struct SymbolicAffect
88
affect::Vector{Equation}
99
alg_eqs::Vector{Equation}
10-
discrete_parameters::Vector{Any}
10+
discrete_parameters::Vector{SymbolicT}
1111
end
1212

1313
function SymbolicAffect(affect::Vector{Equation}; alg_eqs = Equation[],
14-
discrete_parameters = Any[], kwargs...)
15-
if !(discrete_parameters isa AbstractVector)
16-
discrete_parameters = Any[discrete_parameters]
17-
elseif !(discrete_parameters isa Vector{Any})
18-
discrete_parameters = Vector{Any}(discrete_parameters)
14+
discrete_parameters = SymbolicT[], kwargs...)
15+
if symbolic_type(discrete_parameters) !== NotSymbolic()
16+
discrete_parameters = SymbolicT[unwrap(discrete_parameters)]
17+
elseif !(discrete_parameters isa Vector{SymbolicT})
18+
_discs = SymbolicT[]
19+
for p in discrete_parameters
20+
push!(_discs, unwrap(p))
21+
end
22+
discrete_parameters = _discs
1923
end
2024
SymbolicAffect(affect, alg_eqs, discrete_parameters)
2125
end
@@ -33,11 +37,11 @@ struct AffectSystem
3337
"""The internal implicit discrete system whose equations are solved to obtain values after the affect."""
3438
system::AbstractSystem
3539
"""Unknowns of the parent ODESystem whose values are modified or accessed by the affect."""
36-
unknowns::Vector
40+
unknowns::Vector{SymbolicT}
3741
"""Parameters of the parent ODESystem whose values are accessed by the affect."""
38-
parameters::Vector
42+
parameters::Vector{SymbolicT}
3943
"""Parameters of the parent ODESystem whose values are modified by the affect."""
40-
discretes::Vector
44+
discretes::Vector{SymbolicT}
4145
end
4246

4347
function (s::SymbolicUtils.Substituter)(aff::AffectSystem)
@@ -57,49 +61,57 @@ function AffectSystem(spec::SymbolicAffect; iv = nothing, alg_eqs = Equation[],
5761
discrete_parameters = spec.discrete_parameters, kwargs...)
5862
end
5963

60-
function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[],
64+
@noinline function warn_algebraic_equation(eq::Equation)
65+
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
66+
end
67+
68+
function AffectSystem(affect::Vector{Equation}; discrete_parameters = SymbolicT[],
6169
iv = nothing, alg_eqs::Vector{Equation} = Equation[], warn_no_algebraic = true, kwargs...)
6270
isempty(affect) && return nothing
6371
if isnothing(iv)
6472
iv = t_nounits
6573
@warn "No independent variable specified. Defaulting to t_nounits."
6674
end
6775

68-
discrete_parameters isa AbstractVector || (discrete_parameters = [discrete_parameters])
69-
discrete_parameters = unwrap.(discrete_parameters)
76+
discrete_parameters = SymbolicAffect(affect; alg_eqs, discrete_parameters).discrete_parameters
7077

7178
for p in discrete_parameters
7279
SU.query!(isequal(unwrap(iv)), unwrap(p)) ||
7380
error("Non-time dependent parameter $p passed in as a discrete. Must be declared as @parameters $p(t).")
7481
end
7582

76-
dvs = OrderedSet()
77-
params = OrderedSet()
78-
_varsbuf = Set()
83+
dvs = OrderedSet{SymbolicT}()
84+
params = OrderedSet{SymbolicT}()
85+
_varsbuf = Set{SymbolicT}()
7986
for eq in affect
80-
if !haspre(eq) && !(symbolic_type(eq.rhs) === NotSymbolic() ||
81-
symbolic_type(eq.lhs) === NotSymbolic())
82-
@warn "Affect equation $eq has no `Pre` operator. As such it will be interpreted as an algebraic equation to be satisfied after the callback. If you intended to use the value of a variable x before the affect, use Pre(x). Errors may be thrown if there is no `Pre` and the algebraic equation is unsatisfiable, such as X ~ X + 1."
87+
if !haspre(eq) && !(isconst(eq.lhs) && isconst(eq.rhs))
88+
@invokelatest warn_algebraic_equation(eq)
8389
end
8490
collect_vars!(dvs, params, eq, iv; op = Pre)
8591
empty!(_varsbuf)
86-
vars!(_varsbuf, eq; op = Pre)
87-
filter!(x -> iscall(x) && operation(x) isa Pre, _varsbuf)
92+
SU.search_variables!(_varsbuf, eq; is_atomic = OperatorIsAtomic{Pre}())
93+
filter!(x -> iscall(x) && operation(x) === Pre(), _varsbuf)
8894
union!(params, _varsbuf)
8995
diffvs = collect_applied_operators(eq, Differential)
9096
union!(dvs, diffvs)
9197
end
9298
for eq in alg_eqs
9399
collect_vars!(dvs, params, eq, iv)
94100
end
95-
pre_params = filter(haspre value, params)
96-
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
101+
pre_params = filter(haspre, params)
102+
sys_params = SymbolicT[]
103+
disc_ps_set = Set{SymbolicT}(discrete_parameters)
104+
for p in params
105+
p in disc_ps_set && continue
106+
p in pre_params && continue
107+
push!(sys_params, p)
108+
end
97109
discretes = map(tovar, discrete_parameters)
98110
dvs = collect(dvs)
99111
_dvs = map(default_toterm, dvs)
100112

101-
rev_map = Dict(zip(discrete_parameters, discretes))
102-
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
113+
rev_map = Dict{SymbolicT, SymbolicT}(zip(discrete_parameters, discretes))
114+
subs = merge(rev_map, Dict{SymbolicT, SymbolicT}(zip(dvs, _dvs)))
103115
affect = substitute(affect, subs)
104116
alg_eqs = substitute(alg_eqs, subs)
105117

@@ -108,14 +120,13 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[],
108120
collect(union(pre_params, sys_params)); is_discrete = true)
109121
affectsys = mtkcompile(affectsys; fully_determined = nothing)
110122
# get accessed parameters p from Pre(p) in the callback parameters
111-
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
123+
accessed_params = Vector{SymbolicT}(filter(isparameter, map(unPre, collect(pre_params))))
112124
union!(accessed_params, sys_params)
113125

114126
# add scalarized unknowns to the map.
115-
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
127+
_dvs = reduce(vcat, map(scalarize, _dvs), init = SymbolicT[])
116128

117-
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
118-
collect(discrete_parameters))
129+
AffectSystem(affectsys, _dvs, accessed_params, discrete_parameters)
119130
end
120131

121132
system(a::AffectSystem) = a.system

0 commit comments

Comments
 (0)