Skip to content

Commit f7231ba

Browse files
fix: improve inference of several utility functions
1 parent 82eb4ba commit f7231ba

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

src/utils.jl

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -526,13 +526,10 @@ ModelingToolkit.collect_applied_operators(eq, Differential) == Set([D(y)])
526526
527527
The difference compared to `collect_operator_variables` is that `collect_operator_variables` returns the variable without the operator applied.
528528
"""
529-
function collect_applied_operators(x, op)
530-
v = vars(x, op = op)
531-
filter(v) do x
532-
issym(x) && return false
533-
iscall(x) && return operation(x) isa op
534-
false
535-
end
529+
function collect_applied_operators(x::SymbolicT, ::Type{op}) where {op}
530+
v = Set{SymbolicT}()
531+
SU.search_variables!(v, x; is_atomic = OnlyOperatorIsAtomic{op}())
532+
return v
536533
end
537534

538535
"""
@@ -543,12 +540,12 @@ Search through equations and parameter dependencies of `sys`, where sys is at a
543540
recursively searches through all subsystems of `sys`, increasing the depth if it is not
544541
`-1`. A depth of `-1` indicates searching for variables with `GlobalScope`.
545542
"""
546-
function collect_scoped_vars!(unknowns, parameters, sys, iv; depth = 1, op = Differential)
543+
function collect_scoped_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, sys::AbstractSystem, iv::Union{SymbolicT, Nothing}; depth = 1, op = Differential)
547544
if has_eqs(sys)
548545
for eq in equations(sys)
549546
eqtype_supports_collect_vars(eq) || continue
550547
if eq isa Equation
551-
eq.lhs isa Union{SymbolicT, Number} || continue
548+
symtype(eq.lhs) <: Number || continue
552549
end
553550
collect_vars!(unknowns, parameters, eq, iv; depth, op)
554551
end
@@ -622,6 +619,24 @@ function Base.showerror(io::IO, err::OperatorIndepvarMismatchError)
622619
end
623620
end
624621

622+
struct OnlyOperatorIsAtomic{O} end
623+
624+
function (::OnlyOperatorIsAtomic{O})(ex::SymbolicT) where {O}
625+
Moshi.Match.@match ex begin
626+
BSImpl.Term(; f) && if f isa O end => true
627+
_ => false
628+
end
629+
end
630+
631+
struct OperatorIsAtomic{O} end
632+
633+
function (::OperatorIsAtomic{O})(ex::SymbolicT) where {O}
634+
SU.default_is_atomic(ex) && Moshi.Match.@match ex begin
635+
BSImpl.Term(; f) && if f isa Operator end => f isa O
636+
_ => true
637+
end
638+
end
639+
625640
"""
626641
$(TYPEDSIGNATURES)
627642
@@ -636,12 +651,15 @@ can be checked using `check_scope_depth`.
636651
637652
This function should return `nothing`.
638653
"""
639-
function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics.Operator)
640-
if issym(expr)
641-
return collect_var!(unknowns, parameters, expr, iv; depth)
654+
function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::SymbolicT, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator)
655+
Moshi.Match.@match expr begin
656+
BSImpl.Const(;) => return
657+
BSImpl.Sym(;) => return collect_var!(unknowns, parameters, expr, iv; depth)
658+
_ => nothing
642659
end
643-
SymbolicUtils.isconst(expr) && return
644-
for var in vars(expr; op)
660+
vars = Set{SymbolicT}()
661+
SU.search_variables!(vars, expr; is_atomic = OperatorIsAtomic{op}())
662+
for var in vars
645663
while iscall(var) && operation(var) isa op
646664
validate_operator(operation(var), arguments(var), iv; context = expr)
647665
var = arguments(var)[1]
@@ -651,6 +669,13 @@ function collect_vars!(unknowns, parameters, expr, iv; depth = 0, op = Symbolics
651669
return nothing
652670
end
653671

672+
function collect_vars!(unknowns::OrderedSet{SymbolicT}, parameters::OrderedSet{SymbolicT}, expr::AbstractArray{SymbolicT}, iv::Union{SymbolicT, Nothing}; depth = 0, op = Symbolics.Operator)
673+
for var in expr
674+
collect_vars!(unknowns, parameters, var, iv; depth, op)
675+
end
676+
return nothing
677+
end
678+
654679
"""
655680
$(TYPEDSIGNATURES)
656681
@@ -696,7 +721,7 @@ function collect_var!(unknowns, parameters, var, iv; depth = 0)
696721
wrapped symbolic variables.
697722
""")
698723
end
699-
check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing
724+
check_scope_depth(getmetadata(var, SymScope, LocalScope())::AllScopes, depth) || return nothing
700725
var = setmetadata(var, SymScope, LocalScope())
701726
if iscalledparameter(var)
702727
callable = getcalledparameter(var)
@@ -724,7 +749,7 @@ function check_scope_depth(scope, depth)
724749
if scope isa LocalScope
725750
return depth == 0
726751
elseif scope isa ParentScope
727-
return depth > 0 && check_scope_depth(scope.parent, depth - 1)
752+
return depth > 0 && check_scope_depth(scope.parent, depth - 1)::Bool
728753
elseif scope isa GlobalScope
729754
return depth == -1
730755
end
@@ -838,8 +863,8 @@ end
838863
Check if `T` is an appropriate symtype for a symbolic variable representing a floating
839864
point number or array of such numbers.
840865
"""
841-
function is_floatingpoint_symtype(T::Type)
842-
return T == Real || T == Number || T == Complex || T <: AbstractFloat ||
866+
function is_floatingpoint_symtype(T)
867+
return T === Real || T === Number || T === Complex || T <: AbstractFloat ||
843868
T <: AbstractArray && is_floatingpoint_symtype(eltype(T))
844869
end
845870

0 commit comments

Comments
 (0)