Skip to content

Commit b30bcc8

Browse files
fix: make shift2term type-stable
1 parent f7231ba commit b30bcc8

File tree

2 files changed

+55
-40
lines changed

2 files changed

+55
-40
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@ module StructuralTransformations
33
using Setfield: @set!, @set
44
using UnPack: @unpack
55

6-
using Symbolics: unwrap, linear_expansion
6+
using Symbolics: unwrap, linear_expansion, VartypeT, SymbolicT
77
import Symbolics
88
using SymbolicUtils
9+
using SymbolicUtils: BSImpl
910
using SymbolicUtils.Code
1011
using SymbolicUtils.Rewriters
1112
using SymbolicUtils: maketerm, iscall
1213
import SymbolicUtils as SU
14+
import Moshi
1315

1416
using ModelingToolkit
1517
using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Differential,
16-
unknowns, equations, vars, SymbolicT, diff2term_with_unit,
18+
unknowns, equations, vars, diff2term_with_unit,
1719
shift2term_with_unit, value,
1820
operation, arguments, simplify, symbolic_linear_solve,
1921
isdiffeq, isdifferential, isirreducible,
@@ -28,7 +30,8 @@ using ModelingToolkit: System, AbstractSystem, var_from_nested_derivative, Diffe
2830
filter_kwargs, lower_varname_with_unit,
2931
lower_shift_varname_with_unit, setio, SparseMatrixCLIL,
3032
get_fullvars, has_equations, observed,
31-
Schedule, schedule, iscomplete, get_schedule
33+
Schedule, schedule, iscomplete, get_schedule, VariableUnshifted,
34+
VariableShift
3235

3336
using ModelingToolkit.BipartiteGraphs
3437
import .BipartiteGraphs: invview, complete
@@ -41,7 +44,7 @@ using ModelingToolkit: algeqs, EquationsView,
4144
dervars_range, diffvars_range, algvars_range,
4245
DiffGraph, complete!,
4346
get_fullvars, system_subset
44-
using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic
47+
using SymbolicIndexingInterface: symbolic_type, ArraySymbolic, NotSymbolic, getname
4548

4649
using ModelingToolkit.DiffEqBase
4750
using ModelingToolkit.StaticArrays

src/structural_transformation/utils.jl

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -503,43 +503,55 @@ end
503503
"""
504504
Rename a Shift variable with negative shift, Shift(t, k)(x(t)) to xₜ₋ₖ(t).
505505
"""
506-
function shift2term(var)
507-
iscall(var) || return var
508-
op = operation(var)
509-
op isa Shift || return var
510-
iv = op.t
511-
arg = only(arguments(var))
512-
if operation(arg) === getindex
513-
idxs = arguments(arg)[2:end]
514-
newvar = shift2term(op(first(arguments(arg))))[idxs...]
515-
unshifted = ModelingToolkit.getunshifted(newvar)[idxs...]
516-
newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, unshifted)
517-
return newvar
506+
function shift2term(var::SymbolicT)
507+
Moshi.Match.@match var begin
508+
BSImpl.Term(f, args) && if f isa Shift end => begin
509+
op = f
510+
arg = args[1]
511+
Moshi.Match.@match arg begin
512+
BSImpl.Term(; f, args, type, shape, metadata) && if f === getindex end => begin
513+
newargs = copy(parent(args))
514+
newargs[1] = shift2term(op(newargs[1]))
515+
unshifted_args = copy(newargs)
516+
unshifted_args[1] = ModelingToolkit.getunshifted(newargs[1])
517+
unshifted = BSImpl.Term{VartypeT}(getindex, unshifted_args; type, shape, metadata)
518+
if metadata === nothing
519+
metadata = Base.ImmutableDict{DataType, Any}(VariableUnshifted, unshifted)
520+
elseif metadata isa Base.ImmutableDict{DataType, Any}
521+
metadata = Base.ImmutableDict(metadata, VariableUnshifted, unshifted)
522+
end
523+
return BSImpl.Term{VartypeT}(getindex, newargs; type, shape, metadata)
524+
end
525+
_ => nothing
526+
end
527+
unshifted = ModelingToolkit.getunshifted(arg)
528+
is_lowered = unshifted !== nothing
529+
backshift = op.steps + ModelingToolkit.getshift(arg)
530+
io = IOBuffer()
531+
O = (is_lowered ? unshifted : arg)::SymbolicT
532+
write(io, getname(O))
533+
# Char(0x209c) = ₜ
534+
write(io, Char(0x209c))
535+
# Char(0x208b) = ₋ (subscripted minus)
536+
# Char(0x208a) = ₊ (subscripted plus)
537+
pm = backshift > 0 ? Char(0x208a) : Char(0x208b)
538+
write(io, pm)
539+
backshift = abs(backshift)
540+
N = ndigits(backshift)
541+
den = 10 ^ (N - 1)
542+
for _ in 1:N
543+
# subscripted number, e.g. ₁
544+
write(io, Char(0x2080 + div(backshift, den) % 10))
545+
den = div(den, 10)
546+
end
547+
newname = Symbol(take!(io))
548+
newvar = Symbolics.rename(var, newname)
549+
newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O)
550+
newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift)
551+
return newvar
552+
end
553+
_ => return var
518554
end
519-
is_lowered = !isnothing(ModelingToolkit.getunshifted(arg))
520-
521-
backshift = is_lowered ? op.steps + ModelingToolkit.getshift(arg) : op.steps
522-
523-
# Char(0x208b) = ₋ (subscripted minus)
524-
# Char(0x208a) = ₊ (subscripted plus)
525-
pm = backshift > 0 ? Char(0x208a) : Char(0x208b)
526-
# subscripted number, e.g. ₁
527-
num = join(Char(0x2080 + d) for d in reverse!(digits(abs(backshift))))
528-
# Char(0x209c) = ₜ
529-
# ds = ₜ₋₁
530-
ds = join([Char(0x209c), pm, num])
531-
532-
O = is_lowered ? ModelingToolkit.getunshifted(arg) : arg
533-
oldop = operation(O)
534-
newname = backshift != 0 ? Symbol(string(nameof(oldop)), ds) :
535-
Symbol(string(nameof(oldop)))
536-
537-
newvar = maketerm(typeof(O), Symbolics.rename(oldop, newname),
538-
arguments(O), SU.metadata(O))
539-
newvar = setmetadata(newvar, Symbolics.VariableSource, (:variables, newname))
540-
newvar = setmetadata(newvar, ModelingToolkit.VariableUnshifted, O)
541-
newvar = setmetadata(newvar, ModelingToolkit.VariableShift, backshift)
542-
return newvar
543555
end
544556

545557
function isdoubleshift(var)

0 commit comments

Comments
 (0)