Skip to content

Commit 6804adb

Browse files
fix: better handle absent parameter derivatives in simplification
1 parent 83faba1 commit 6804adb

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@ function eq_derivative!(ts::TearingState, ieq::Int; kwargs...)
7272
vs = Set{SymbolicT}()
7373
SU.search_variables!(vs, eq.rhs)
7474
for v in vs
75-
# parameters with unknown derivatives have a value of `nothing` in the map,
76-
# so use `missing` as the default.
77-
get(ts.param_derivative_map, v, missing) === nothing || continue
75+
v in ts.no_deriv_params || continue
7876
_original_eq = equations(ts)[ieq]
7977
error("""
8078
Encountered derivative of discrete variable `$(only(arguments(v)))` when \

src/systems/systemstructure.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
209209
structure::SystemStructure
210210
extra_eqs::Vector{Equation}
211211
param_derivative_map::Dict{SymbolicT, SymbolicT}
212+
no_deriv_params::Set{SymbolicT}
212213
original_eqs::Vector{Equation}
213214
"""
214215
Additional user-provided observed equations. The variables calculated here
@@ -362,6 +363,7 @@ function TearingState(sys; check = true, sort_eqs = true)
362363
original_eqs = copy(eqs)
363364
neqs = length(eqs)
364365
param_derivative_map = Dict{SymbolicT, SymbolicT}()
366+
no_deriv_params = Set{SymbolicT}()
365367
fullvars = SymbolicT[]
366368
# * Scalarize unknowns
367369
dvs = Set{SymbolicT}()
@@ -380,7 +382,7 @@ function TearingState(sys; check = true, sort_eqs = true)
380382
varsbuf = Set{SymbolicT}()
381383
eqs_to_retain = trues(length(eqs))
382384
for (i, eq) in enumerate(eqs)
383-
eq, is_statemachine_equation = canonicalize_eq!(param_derivative_map, eqs_to_retain, ps, iv, i, eq)
385+
eq, is_statemachine_equation = canonicalize_eq!(param_derivative_map, no_deriv_params, eqs_to_retain, ps, iv, i, eq)
384386
empty!(varsbuf)
385387
SU.search_variables!(varsbuf, eq; is_atomic = OperatorIsAtomic{SU.Operator}())
386388
incidence = Set{SymbolicT}()
@@ -396,7 +398,7 @@ function TearingState(sys; check = true, sort_eqs = true)
396398
if symbolic_contains(v, ps) ||
397399
getmetadata(v, SymScope, LocalScope()) isa GlobalScope && isparameter(v)
398400
if is_time_dependent_parameter(v, ps, iv) &&
399-
!haskey(param_derivative_map, Differential(iv)(v))
401+
!haskey(param_derivative_map, Differential(iv)(v)) && !(Differential(iv)(v) in no_deriv_params)
400402
# Parameter derivatives default to zero - they stay constant
401403
# between callbacks
402404
param_derivative_map[Differential(iv)(v)] = Symbolics.COMMON_ZERO
@@ -480,6 +482,8 @@ function TearingState(sys; check = true, sort_eqs = true)
480482
push!(symbolic_incidence, collect(incidence))
481483
end
482484

485+
filter!(Base.Fix2(!==, COMMON_NOTHING) last, param_derivative_map)
486+
483487
eqs = eqs[eqs_to_retain]
484488
original_eqs = original_eqs[eqs_to_retain]
485489
neqs = length(eqs)
@@ -520,7 +524,7 @@ function TearingState(sys; check = true, sort_eqs = true)
520524
return TearingState{typeof(sys)}(sys, fullvars,
521525
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
522526
complete(graph), nothing, var_types, false),
523-
Equation[], param_derivative_map, original_eqs, Equation[], typeof(sys)[])
527+
Equation[], param_derivative_map, no_deriv_params, original_eqs, Equation[], typeof(sys)[])
524528
end
525529

526530
function sort_fullvars(fullvars::Vector{SymbolicT}, dervaridxs::Vector{Int}, var_types::Vector{VariableType}, @nospecialize(iv::Union{SymbolicT, Nothing}))
@@ -594,7 +598,7 @@ function collect_vars_to_set!(buffer::Set{SymbolicT}, vars::Vector{SymbolicT})
594598
end
595599
end
596600

597-
function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, eqs_to_retain::BitVector, ps::Set{SymbolicT}, @nospecialize(iv::Union{Nothing, SymbolicT}), i::Int, eq::Equation)
601+
function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, no_deriv_params::Set{SymbolicT}, eqs_to_retain::BitVector, ps::Set{SymbolicT}, @nospecialize(iv::Union{Nothing, SymbolicT}), i::Int, eq::Equation)
598602
is_statemachine_equation = false
599603
lhs = eq.lhs
600604
rhs = eq.rhs
@@ -612,6 +616,8 @@ function canonicalize_eq!(param_derivative_map::Dict{SymbolicT, SymbolicT}, eqs_
612616
else
613617
# change the equation if the RHS is `missing` so the rest of this loop works
614618
eq = Symbolics.COMMON_ZERO ~ Symbolics.COMMON_ZERO
619+
push!(no_deriv_params, lhs)
620+
delete!(param_derivative_map, lhs)
615621
end
616622
eqs_to_retain[i] = false
617623
end

0 commit comments

Comments
 (0)