Skip to content

Commit 587476c

Browse files
refactor: make System more concretely typed
1 parent 8393cb7 commit 587476c

File tree

2 files changed

+81
-52
lines changed

2 files changed

+81
-52
lines changed

src/ModelingToolkit.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ include("parameters.jl")
153153
include("independent_variables.jl")
154154
include("constants.jl")
155155

156+
const SymmapT = Dict{SymbolicT, SymbolicT}
157+
const COMMON_NOTHING = SU.Const{VartypeT}(nothing)
158+
156159
include("utils.jl")
157160

158161
include("systems/index_cache.jl")

src/systems/system.jl

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ struct System <: IntermediateDeprecationSystem
6363
loss of an optimization problem. Scalar loss values must also be provided as a single-
6464
element vector.
6565
"""
66-
costs::Vector{<:Union{BasicSymbolic, Real}}
66+
costs::Vector{SymbolicT}
6767
"""
6868
A function which combines costs into a scalar value. This should take two arguments,
6969
the `costs` of this system and the consolidated costs of all subsystems in the order
@@ -76,20 +76,20 @@ struct System <: IntermediateDeprecationSystem
7676
The variables being solved for by this system. For example, in a differential equation
7777
system, this contains the dependent variables.
7878
"""
79-
unknowns::Vector
79+
unknowns::Vector{SymbolicT}
8080
"""
8181
The parameters of the system. Parameters can either be variables that parameterize the
8282
problem being solved for (e.g. the spring constant of a mass-spring system) or
8383
additional unknowns not part of the main dynamics of the system (e.g. discrete/clocked
8484
variables in a hybrid ODE).
8585
"""
86-
ps::Vector
86+
ps::Vector{SymbolicT}
8787
"""
8888
The brownian variables of the system, created via `@brownians`. Each brownian variable
8989
represents an independent noise. A system with brownians cannot be simulated directly.
9090
It needs to be compiled using `mtkcompile` into `noise_eqs`.
9191
"""
92-
brownians::Vector
92+
brownians::Vector{SymbolicT}
9393
"""
9494
The independent variable for a time-dependent system, or `nothing` for a time-independent
9595
system.
@@ -117,7 +117,7 @@ struct System <: IntermediateDeprecationSystem
117117
A mapping from the name of a variable to the actual symbolic variable in the system.
118118
This is used to enable `getproperty` syntax to access variables of a system.
119119
"""
120-
var_to_name::Dict{Symbol, Any}
120+
var_to_name::Dict{Symbol, SymbolicT}
121121
"""
122122
The name of the system.
123123
"""
@@ -132,11 +132,11 @@ struct System <: IntermediateDeprecationSystem
132132
by initial values provided to the problem constructor. Defaults of parent systems
133133
take priority over those in child systems.
134134
"""
135-
defaults::Dict
135+
defaults::SymmapT
136136
"""
137137
Guess values for variables of a system that are solved for during initialization.
138138
"""
139-
guesses::Dict
139+
guesses::SymmapT
140140
"""
141141
A list of subsystems of this system. Used for hierarchically building models.
142142
"""
@@ -167,7 +167,7 @@ struct System <: IntermediateDeprecationSystem
167167
associated error message. By default these assertions cause the generated code to
168168
output `NaN`s if violated, but can be made to error using `debug_system`.
169169
"""
170-
assertions::Dict{BasicSymbolic, String}
170+
assertions::Dict{SymbolicT, String}
171171
"""
172172
The metadata associated with this system, as a `Base.ImmutableDict`. This follows
173173
the same interface as SymbolicUtils.jl. Metadata can be queried and updated using
@@ -193,12 +193,12 @@ struct System <: IntermediateDeprecationSystem
193193
$INTERNAL_FIELD_WARNING
194194
The list of input variables of the system.
195195
"""
196-
inputs::OrderedSet{BasicSymbolic}
196+
inputs::OrderedSet{SymbolicT}
197197
"""
198198
$INTERNAL_FIELD_WARNING
199199
The list of output variables of the system.
200200
"""
201-
outputs::OrderedSet{BasicSymbolic}
201+
outputs::OrderedSet{SymbolicT}
202202
"""
203203
The `TearingState` of the system post-simplification with `mtkcompile`.
204204
"""
@@ -264,9 +264,9 @@ struct System <: IntermediateDeprecationSystem
264264
tag, eqs, noise_eqs, jumps, constraints, costs, consolidate, unknowns, ps,
265265
brownians, iv, observed, parameter_dependencies, var_to_name, name, description,
266266
defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events,
267-
connector_type, assertions = Dict{BasicSymbolic, String}(),
267+
connector_type, assertions = Dict{SymbolicT, String}(),
268268
metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [],
269-
inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(),
269+
inputs = Set{SymbolicT}(), outputs = Set{SymbolicT}(),
270270
tearing_state = nothing, namespacing = true,
271271
complete = false, index_cache = nothing, ignored_connections = nothing,
272272
preface = nothing, parent = nothing, initializesystem = nothing,
@@ -321,6 +321,26 @@ function default_consolidate(costs, subcosts)
321321
return reduce(+, costs; init = 0.0) + reduce(+, subcosts; init = 0.0)
322322
end
323323

324+
function unwrap_vars(vars::AbstractArray{SymbolicT})
325+
vec(vars)
326+
end
327+
function unwrap_vars(vars)
328+
result = SymbolicT[]
329+
for var in vars
330+
push!(result, unwrap(var))
331+
end
332+
return result
333+
end
334+
335+
defsdict(x::SymmapT) = x
336+
function defsdict(x::AbstractDict)
337+
result = SymmapT()
338+
for (k, v) in x
339+
result[unwrap(k)] = SU.Const{VartypeT}(v)
340+
end
341+
return result
342+
end
343+
324344
"""
325345
$(TYPEDSIGNATURES)
326346
@@ -337,71 +357,72 @@ for time-independent systems, unknowns `dvs`, parameters `ps` and brownian varia
337357
All other keyword arguments are named identically to the corresponding fields in
338358
[`System`](@ref).
339359
"""
340-
function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
341-
constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = [],
342-
costs = BasicSymbolic[], consolidate = default_consolidate,
343-
observed = Equation[], parameter_dependencies = Equation[], defaults = Dict(),
344-
guesses = Dict(), systems = System[], initialization_eqs = Equation[],
360+
function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
361+
constraints = Union{Equation, Inequality}[], noise_eqs = nothing, jumps = JumpType[],
362+
costs = SymbolicT[], consolidate = default_consolidate,
363+
observed = Equation[], parameter_dependencies = Equation[], defaults = SymmapT(),
364+
guesses = SymmapT(), systems = System[], initialization_eqs = Equation[],
345365
continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[],
346-
connector_type = nothing, assertions = Dict{BasicSymbolic, String}(),
366+
connector_type = nothing, assertions = Dict{SymbolicT, String}(),
347367
metadata = MetadataT(), gui_metadata = nothing,
348-
is_dde = nothing, tstops = [], inputs = OrderedSet{BasicSymbolic}(),
349-
outputs = OrderedSet{BasicSymbolic}(), tearing_state = nothing,
368+
is_dde = nothing, tstops = [], inputs = OrderedSet{SymbolicT}(),
369+
outputs = OrderedSet{SymbolicT}(), tearing_state = nothing,
350370
ignored_connections = nothing, parent = nothing,
351371
description = "", name = nothing, discover_from_metadata = true,
352372
initializesystem = nothing, is_initializesystem = false, is_discrete = false,
353373
preface = [], checks = true)
354374
name === nothing && throw(NoNameError())
355375
if !isempty(parameter_dependencies)
356-
@warn """
357-
The `parameter_dependencies` keyword argument is deprecated. Please provide all
358-
such equations as part of the normal equations of the system.
359-
"""
376+
@invokelatest warn_pdeps()
360377
eqs = Equation[eqs; parameter_dependencies]
361378
end
362379

363380
iv = unwrap(iv)
364-
ps = unwrap.(ps)
365-
dvs = unwrap.(dvs)
366-
filter!(!Base.Fix2(isdelay, iv), dvs)
367-
brownians = unwrap.(brownians)
381+
ps = unwrap_vars(ps)
382+
dvs = unwrap_vars(dvs)
383+
if iv !== nothing
384+
filter!(!Base.Fix2(isdelay, iv), dvs)
385+
end
386+
brownians = unwrap_vars(brownians)
368387

369-
if !(eqs isa AbstractArray)
370-
eqs = [eqs]
388+
if !(eqs isa Vector{Equation})
389+
eqs = Equation[eqs]
371390
end
391+
eqs = eqs::Vector{Equation}
372392

373393
if noise_eqs !== nothing
374394
noise_eqs = unwrap.(noise_eqs)
375395
end
376396

377-
costs = unwrap.(costs)
378-
if isempty(costs)
379-
costs = Union{BasicSymbolic, Real}[]
380-
end
397+
costs = unwrap_vars(costs)
381398

382-
defaults = anydict(defaults)
383-
guesses = anydict(guesses)
384-
inputs = OrderedSet{BasicSymbolic}(inputs)
385-
outputs = OrderedSet{BasicSymbolic}(outputs)
399+
defaults = defsdict(defaults)
400+
guesses = defsdict(guesses)
401+
if !(inputs isa OrderedSet{SymbolicT})
402+
inputs = OrderedSet{SymbolicT}(inputs)
403+
end
404+
if !(outputs isa OrderedSet{SymbolicT})
405+
outputs = OrderedSet{SymbolicT}(outputs)
406+
end
386407
for subsys in systems
387-
for var in ModelingToolkit.inputs(subsys)
408+
for var in get_inputs(subsys)
388409
push!(inputs, renamespace(subsys, var))
389410
end
390-
for var in ModelingToolkit.outputs(subsys)
411+
for var in get_outputs(subsys)
391412
push!(outputs, renamespace(subsys, var))
392413
end
393414
end
394-
var_to_name = anydict()
415+
var_to_name = Dict{Symbol, SymbolicT}()
395416

396-
let defaults = discover_from_metadata ? defaults : Dict(),
397-
guesses = discover_from_metadata ? guesses : Dict(),
398-
inputs = discover_from_metadata ? inputs : Set(),
399-
outputs = discover_from_metadata ? outputs : Set()
417+
let defaults = discover_from_metadata ? defaults : SymmapT(),
418+
guesses = discover_from_metadata ? guesses : SymmapT(),
419+
inputs = discover_from_metadata ? inputs : OrderedSet{SymbolicT}(),
420+
outputs = discover_from_metadata ? outputs : OrderedSet{SymbolicT}()
400421

401422
process_variables!(var_to_name, defaults, guesses, dvs)
402423
process_variables!(var_to_name, defaults, guesses, ps)
403-
process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed])
404-
process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed])
424+
process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.lhs for eq in observed])
425+
process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.rhs for eq in observed])
405426

406427
for var in dvs
407428
if isinput(var)
@@ -411,10 +432,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
411432
end
412433
end
413434
end
414-
filter!(!(isnothing last), defaults)
415-
filter!(!(isnothing last), guesses)
416-
defaults = anydict([unwrap(k) => unwrap(v) for (k, v) in defaults])
417-
guesses = anydict([unwrap(k) => unwrap(v) for (k, v) in guesses])
435+
filter!(!(Base.Fix1(===, COMMON_NOTHING) last), defaults)
436+
filter!(!(Base.Fix1(===, COMMON_NOTHING) last), guesses)
418437

419438
sysnames = nameof.(systems)
420439
unique_sysnames = Set(sysnames)
@@ -433,7 +452,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
433452
is_dde = _check_if_dde(eqs, iv, systems)
434453
end
435454

436-
assertions = Dict{BasicSymbolic, String}(unwrap(k) => v for (k, v) in assertions)
455+
assertions = Dict{SymbolicT, String}(unwrap(k) => v for (k, v) in assertions)
437456

438457
if isempty(metadata)
439458
metadata = MetadataT()
@@ -456,6 +475,13 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
456475
initializesystem, is_initializesystem, is_discrete; checks)
457476
end
458477

478+
@noinline function warn_pdeps()
479+
@warn """
480+
The `parameter_dependencies` keyword argument is deprecated. Please provide all
481+
such equations as part of the normal equations of the system.
482+
"""
483+
end
484+
459485
SymbolicIndexingInterface.getname(x::System) = nameof(x)
460486

461487
"""

0 commit comments

Comments
 (0)