Skip to content

Commit 5cad99f

Browse files
refactor: get System to precompile in a trivial case
1 parent d2a74a0 commit 5cad99f

File tree

3 files changed

+117
-84
lines changed

3 files changed

+117
-84
lines changed

src/systems/callbacks.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,9 @@ conditions(cb::AbstractCallback) = cb.conditions
568568
function conditions(cbs::Vector{<:AbstractCallback})
569569
reduce(vcat, conditions(cb) for cb in cbs; init = [])
570570
end
571+
function conditions(cbs::Vector{SymbolicContinuousCallback})
572+
mapreduce(conditions, vcat, cbs; init = Equation[])
573+
end
571574
equations(cb::AbstractCallback) = conditions(cb)
572575
equations(cb::Vector{<:AbstractCallback}) = conditions(cb)
573576

src/systems/system.jl

Lines changed: 69 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct System <: IntermediateDeprecationSystem
4646
this noise matrix is diagonal. Diagonal noise can be specified by providing an `N`
4747
length vector. If this field is `nothing`, the system does not have noise.
4848
"""
49-
noise_eqs::Union{Nothing, AbstractVector, AbstractMatrix}
49+
noise_eqs::Union{Nothing, Vector{SymbolicT}, Matrix{SymbolicT}}
5050
"""
5151
Jumps associated with the system. Each jump can be a `VariableRateJump`,
5252
`ConstantRateJump` or `MassActionJump`. See `JumpProcesses.jl` for more information.
@@ -279,30 +279,37 @@ struct System <: IntermediateDeprecationSystem
279279
"""))
280280
end
281281
@assert iv === nothing || symtype(iv) === Real
282-
jumps = Vector{JumpType}(jumps)
283-
if (checks == true || (checks & CheckComponents) > 0) && iv !== nothing
284-
check_independent_variables([iv])
282+
if (checks isa Bool && checks === true || checks isa Int && (checks & CheckComponents) > 0) && iv !== nothing
283+
check_independent_variables((iv,))
285284
check_variables(unknowns, iv)
286285
check_parameters(ps, iv)
287286
check_equations(eqs, iv)
288-
if noise_eqs !== nothing && size(noise_eqs, 1) != length(eqs)
289-
throw(IllFormedNoiseEquationsError(size(noise_eqs, 1), length(eqs)))
287+
Neq = length(eqs)
288+
if noise_eqs isa Matrix{SymbolicT}
289+
N1 = size(noise_eqs, 1)
290+
elseif noise_eqs isa Vector{SymbolicT}
291+
N1 = length(noise_eqs)
292+
elseif noise_eqs === nothing
293+
N1 = Neq
294+
else
295+
error()
290296
end
297+
N1 == Neq || throw(IllFormedNoiseEquationsError(N1, Neq))
291298
check_equations(equations(continuous_events), iv)
292299
check_subsystems(systems)
293300
end
294-
if checks == true || (checks & CheckUnits) > 0
295-
u = __get_unit_type(unknowns, ps, iv)
296-
if noise_eqs === nothing
297-
check_units(u, eqs)
298-
else
299-
check_units(u, eqs, noise_eqs)
300-
end
301-
if iv !== nothing
302-
check_units(u, jumps, iv)
303-
end
304-
isempty(constraints) || check_units(u, constraints)
305-
end
301+
# if checks == true || (checks & CheckUnits) > 0
302+
# u = __get_unit_type(unknowns, ps, iv)
303+
# if noise_eqs === nothing
304+
# check_units(u, eqs)
305+
# else
306+
# check_units(u, eqs, noise_eqs)
307+
# end
308+
# if iv !== nothing
309+
# check_units(u, jumps, iv)
310+
# end
311+
# isempty(constraints) || check_units(u, constraints)
312+
# end
306313
new(tag, eqs, noise_eqs, jumps, constraints, costs,
307314
consolidate, unknowns, ps, brownians, iv,
308315
observed, parameter_dependencies, var_to_name, name, description, defaults,
@@ -321,13 +328,11 @@ function default_consolidate(costs, subcosts)
321328
return reduce(+, costs; init = 0.0) + reduce(+, subcosts; init = 0.0)
322329
end
323330

324-
function unwrap_vars(vars::AbstractArray{SymbolicT})
325-
vec(vars)
326-
end
327-
function unwrap_vars(vars)
328-
result = SymbolicT[]
329-
for var in vars
330-
push!(result, unwrap(var))
331+
unwrap_vars(vars::AbstractArray{SymbolicT}) = vars
332+
function unwrap_vars(vars::AbstractArray)
333+
result = similar(vars, SymbolicT)
334+
for i in eachindex(vars)
335+
result[i] = SU.Const{VartypeT}(vars[i])
331336
end
332337
return result
333338
end
@@ -372,29 +377,30 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
372377
initializesystem = nothing, is_initializesystem = false, is_discrete = false,
373378
preface = [], checks = true)
374379
name === nothing && throw(NoNameError())
380+
381+
if !(eqs isa Vector{Equation})
382+
eqs = Equation[eqs]
383+
end
384+
eqs = eqs::Vector{Equation}
385+
375386
if !isempty(parameter_dependencies)
376387
@invokelatest warn_pdeps()
377-
eqs = Equation[eqs; parameter_dependencies]
388+
append!(eqs, parameter_dependencies)
378389
end
379390

380391
iv = unwrap(iv)
381-
ps = unwrap_vars(ps)
382-
dvs = unwrap_vars(dvs)
392+
ps = vec(unwrap_vars(ps))
393+
dvs = vec(unwrap_vars(dvs))
383394
if iv !== nothing
384395
filter!(!Base.Fix2(isdelay, iv), dvs)
385396
end
386397
brownians = unwrap_vars(brownians)
387398

388-
if !(eqs isa Vector{Equation})
389-
eqs = Equation[eqs]
390-
end
391-
eqs = eqs::Vector{Equation}
392-
393399
if noise_eqs !== nothing
394-
noise_eqs = unwrap.(noise_eqs)
400+
noise_eqs = unwrap_vars(noise_eqs)
395401
end
396402

397-
costs = unwrap_vars(costs)
403+
costs = vec(unwrap_vars(costs))
398404

399405
defaults = defsdict(defaults)
400406
guesses = defsdict(guesses)
@@ -421,8 +427,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
421427

422428
process_variables!(var_to_name, defaults, guesses, dvs)
423429
process_variables!(var_to_name, defaults, guesses, ps)
424-
process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.lhs for eq in observed])
425-
process_variables!(var_to_name, defaults, guesses, SymbolicT[eq.rhs for eq in observed])
430+
buffer = SymbolicT[]
431+
for eq in observed
432+
push!(buffer, eq.lhs)
433+
push!(buffer, eq.rhs)
434+
end
435+
process_variables!(var_to_name, defaults, guesses, buffer)
426436

427437
for var in dvs
428438
if isinput(var)
@@ -435,10 +445,9 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
435445
filter!(!(Base.Fix1(===, COMMON_NOTHING) last), defaults)
436446
filter!(!(Base.Fix1(===, COMMON_NOTHING) last), guesses)
437447

438-
sysnames = nameof.(systems)
439-
unique_sysnames = Set(sysnames)
440-
if length(unique_sysnames) != length(sysnames)
441-
throw(NonUniqueSubsystemsError(sysnames, unique_sysnames))
448+
449+
if !allunique(map(nameof, systems))
450+
nonunique_subsystems(systems)
442451
end
443452
continuous_events,
444453
discrete_events = create_symbolic_events(
@@ -452,7 +461,10 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
452461
is_dde = _check_if_dde(eqs, iv, systems)
453462
end
454463

455-
assertions = Dict{SymbolicT, String}(unwrap(k) => v for (k, v) in assertions)
464+
_assertions = Dict{SymbolicT, String}
465+
for (k, v) in assertions
466+
_assertions[unwrap(k)::SymbolicT] = v
467+
end
456468

457469
if isempty(metadata)
458470
metadata = MetadataT()
@@ -466,6 +478,7 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
466478
metadata = meta
467479
end
468480
metadata = refreshed_metadata(metadata)
481+
jumps = Vector{JumpType}(jumps)
469482
System(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, noise_eqs, jumps, constraints,
470483
costs, consolidate, dvs, ps, brownians, iv, observed, Equation[],
471484
var_to_name, name, description, defaults, guesses, systems, initialization_eqs,
@@ -475,6 +488,12 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = SymbolicT[];
475488
initializesystem, is_initializesystem, is_discrete; checks)
476489
end
477490

491+
@noinline function nonunique_subsystems(systems)
492+
sysnames = nameof.(systems)
493+
unique_sysnames = Set(sysnames)
494+
throw(NonUniqueSubsystemsError(sysnames, unique_sysnames))
495+
end
496+
478497
@noinline function warn_pdeps()
479498
@warn """
480499
The `parameter_dependencies` keyword argument is deprecated. Please provide all
@@ -749,19 +768,15 @@ differential equations.
749768
"""
750769
is_dde(sys::AbstractSystem) = has_is_dde(sys) && get_is_dde(sys)
751770

752-
function _check_if_dde(eqs, iv, subsystems)
753-
is_dde = any(ModelingToolkit.is_dde, subsystems)
754-
if !is_dde
755-
vs = Set()
756-
for eq in eqs
757-
vars!(vs, eq)
758-
is_dde = any(vs) do sym
759-
isdelay(unwrap(sym), iv)
760-
end
761-
is_dde && break
762-
end
771+
_check_if_dde(eqs::Vector{Equation}, iv::Nothing, subsystems::Vector{System}) = false
772+
function _check_if_dde(eqs::Vector{Equation}, iv::SymbolicT, subsystems::Vector{System})
773+
any(ModelingToolkit.is_dde, subsystems) && return true
774+
pred = Base.Fix2(isdelay, iv)
775+
for eq in eqs
776+
SU.query!(pred, eq.lhs) && return true
777+
SU.query!(pred, eq.rhs) && return true
763778
end
764-
return is_dde
779+
return false
765780
end
766781

767782
"""

src/utils.jl

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -117,34 +117,31 @@ const CheckUnits = 1 << 2
117117

118118
function check_independent_variables(ivs)
119119
for iv in ivs
120-
isparameter(iv) ||
121-
@warn "Independent variable $iv should be defined with @independent_variables $iv."
120+
isparameter(iv) || @invokelatest warn_indepvar(iv)
122121
end
123122
end
124123

124+
@noinline function warn_indepvar(iv::SymbolicT)
125+
@warn "Independent variable $iv should be defined with @independent_variables $iv."
126+
end
127+
125128
function check_parameters(ps, iv)
126129
for p in ps
127130
isequal(iv, p) &&
128131
throw(ArgumentError("Independent variable $iv not allowed in parameters."))
129132
end
130133
end
131134

132-
function is_delay_var(iv, var)
133-
if Symbolics.isarraysymbolic(var)
134-
return is_delay_var(iv, first(collect(var)))
135-
end
136-
args = nothing
137-
try
138-
args = arguments(var)
139-
catch
140-
return false
135+
function is_delay_var(iv::SymbolicT, var::SymbolicT)
136+
Moshi.Match.@match var begin
137+
BSImpl.Term(; f, args) => begin
138+
length(args) > 1 && return false
139+
arg = args[1]
140+
isequal(arg, iv) && return false
141+
return symtype(arg) <: Real
142+
end
143+
_ => false
141144
end
142-
length(args) > 1 && return false
143-
isequal(first(args), iv) && return false
144-
delay = iv - first(args)
145-
delay isa Integer ||
146-
delay isa AbstractFloat ||
147-
(delay isa Num && isreal(value(delay)))
148145
end
149146

150147
function check_variables(dvs, iv)
@@ -187,20 +184,35 @@ function collect_ivs(eqs, op = Differential)
187184
return ivs
188185
end
189186

187+
struct IndepvarCheckPredicate
188+
iv::SymbolicT
189+
end
190+
191+
function (icp::IndepvarCheckPredicate)(ex::SymbolicT)
192+
Moshi.Match.@match ex begin
193+
BSImpl.Term(; f) && if f isa Differential end => begin
194+
f = f::Differential
195+
isequal(f.x, icp.iv) || throw_multiple_iv(icp.iv, f.x)
196+
return false
197+
end
198+
_ => false
199+
end
200+
end
201+
202+
@noinline function throw_multiple_iv(iv, newiv)
203+
throw(ArgumentError("Differential w.r.t. variable ($newiv) other than the independent variable ($iv) are not allowed."))
204+
end
205+
190206
"""
191207
check_equations(eqs, iv)
192208
193209
Assert that equations are well-formed when building ODE, i.e., only containing a single independent variable.
194210
"""
195-
function check_equations(eqs, iv)
196-
ivs = collect_ivs(eqs)
197-
display = collect(ivs)
198-
length(ivs) <= 1 ||
199-
throw(ArgumentError("Differential w.r.t. multiple variables $display are not allowed."))
200-
if length(ivs) == 1
201-
single_iv = pop!(ivs)
202-
isequal(single_iv, iv) ||
203-
throw(ArgumentError("Differential w.r.t. variable ($single_iv) other than the independent variable ($iv) are not allowed."))
211+
function check_equations(eqs::Vector{Equation}, iv::SymbolicT)
212+
icp = IndepvarCheckPredicate(iv)
213+
for eq in eqs
214+
SU.query!(icp, eq.lhs)
215+
SU.query!(icp, eq.rhs)
204216
end
205217
end
206218

@@ -211,10 +223,12 @@ Assert that the subsystems have the appropriate namespacing behavior.
211223
"""
212224
function check_subsystems(systems)
213225
idxs = findall(!does_namespacing, systems)
214-
if !isempty(idxs)
215-
names = join(" " .* string.(nameof.(systems[idxs])), "\n")
216-
throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)"))
217-
end
226+
isempty(idxs) || throw_bad_namespacing(systems, idxs)
227+
end
228+
229+
@noinline function throw_bad_namespacing(systems, idxs)
230+
names = join(" " .* string.(nameof.(systems[idxs])), "\n")
231+
throw(ArgumentError("All subsystems must have namespacing enabled. The following subsystems do not perform namespacing:\n$(names)"))
218232
end
219233

220234
"""
@@ -626,6 +640,7 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics
626640
if issym(expr)
627641
return collect_var!(unknowns, parameters, expr, iv; depth)
628642
end
643+
SymbolicUtils.isconst(expr) && return
629644
for var in vars(expr; op)
630645
while iscall(var) && operation(var) isa op
631646
validate_operator(operation(var), arguments(var), iv; context = expr)

0 commit comments

Comments
 (0)