Skip to content

Commit 82eb4ba

Browse files
fix: make default_toterm type stable
1 parent 692928b commit 82eb4ba

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

src/variables.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -245,19 +245,30 @@ chosen as a state in `mtkcompile`.
245245
"""
246246
state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64
247247

248-
normalize_to_differential(x) = x
248+
function normalize_to_differential(@nospecialize(op))
249+
if op isa Shift && op.t isa SymbolicT
250+
return Differential(op.t) ^ op.steps
251+
else
252+
return op
253+
end
254+
end
249255

250-
function default_toterm(x)
251-
if iscall(x) && (op = operation(x)) isa Operator
252-
if !(op isa Differential)
253-
if op isa Shift && op.steps < 0
256+
default_toterm(x) = x
257+
function default_toterm(x::SymbolicT)
258+
Moshi.Match.@match x begin
259+
BSImpl.Term(; f, args, shape, type, metadata) && if f isa Operator end => begin
260+
if f isa Shift && f.steps < 0
254261
return shift2term(x)
262+
elseif f isa Differential
263+
return Symbolics.diff2term(x)
264+
else
265+
newf = normalize_to_differential(f)
266+
f === newf && return x
267+
x = BSImpl.Term{VartypeT}(newf, args; type, shape, metadata)
268+
return Symbolics.diff2term(x)
255269
end
256-
x = normalize_to_differential(op)(arguments(x)...)
257270
end
258-
Symbolics.diff2term(x)
259-
else
260-
x
271+
_ => return x
261272
end
262273
end
263274

0 commit comments

Comments
 (0)