Skip to content

Commit 41ca68b

Browse files
fix: make add_initialization_parameters type-stable
1 parent badfb9f commit 41ca68b

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

src/systems/abstractsystem.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -537,16 +537,20 @@ function add_initialization_parameters(sys::AbstractSystem; split = true)
537537
supports_initialization(sys) || return sys
538538
is_initializesystem(sys) && return sys
539539

540-
all_initialvars = Set{BasicSymbolic}()
540+
all_initialvars = Set{SymbolicT}()
541541
# time-independent systems don't initialize unknowns
542542
# but may initialize parameters using guesses for unknowns
543543
eqs = equations(sys)
544-
if !(eqs isa Vector{Equation})
545-
eqs = Equation[x for x in eqs if x isa Equation]
546-
end
547544
obs, eqs = unhack_observed(observed(sys), eqs)
548-
for x in Iterators.flatten((unknowns(sys), Iterators.map(eq -> eq.lhs, obs)))
549-
x = unwrap(x)
545+
for x in unknowns(sys)
546+
if iscall(x) && operation(x) == getindex && split
547+
push!(all_initialvars, arguments(x)[1])
548+
else
549+
push!(all_initialvars, x)
550+
end
551+
end
552+
for eq in obs
553+
x = eq.lhs
550554
if iscall(x) && operation(x) == getindex && split
551555
push!(all_initialvars, arguments(x)[1])
552556
else
@@ -556,15 +560,19 @@ function add_initialization_parameters(sys::AbstractSystem; split = true)
556560

557561
# add derivatives of all variables for steady-state initial conditions
558562
if is_time_dependent(sys) && !is_discrete_system(sys)
559-
D = Differential(get_iv(sys))
560-
union!(all_initialvars, [D(v) for v in all_initialvars if iscall(v)])
563+
D = Differential(get_iv(sys)::SymbolicT)
564+
for v in all_initialvars
565+
iscall(v) && push!(all_initialvars, D(v))
566+
end
561567
end
562568
for eq in get_parameter_dependencies(sys)
563569
is_variable_floatingpoint(eq.lhs) || continue
564570
push!(all_initialvars, eq.lhs)
565571
end
566-
all_initialvars = collect(all_initialvars)
567-
initials = map(Initial(), all_initialvars)
572+
initials = collect(all_initialvars)
573+
for (i, v) in enumerate(initials)
574+
initials[i] = Initial()(v)
575+
end
568576
@set! sys.ps = unique!([get_ps(sys); initials])
569577
defs = copy(get_defaults(sys))
570578
for ivar in initials

0 commit comments

Comments
 (0)