@@ -30,10 +30,8 @@ in the returned tuple, in which case the associated field will not be updated.
3030"""
3131struct ImperativeAffect
3232 f:: Any
33- obs:: Vector
34- obs_syms:: Vector{Symbol}
35- modified:: Vector
36- mod_syms:: Vector{Symbol}
33+ observed:: NamedTuple
34+ modified:: NamedTuple
3735 ctx:: Any
3836 skip_checks:: Bool
3937end
@@ -43,10 +41,7 @@ function ImperativeAffect(f;
4341 modified:: NamedTuple = NamedTuple {()} (()),
4442 ctx = nothing ,
4543 skip_checks = false )
46- ImperativeAffect (f,
47- collect (values (observed)), collect (keys (observed)),
48- collect (values (modified)), collect (keys (modified)),
49- ctx, skip_checks)
44+ ImperativeAffect (f, observed, modified, ctx, skip_checks)
5045end
5146function ImperativeAffect (f, modified:: NamedTuple ;
5247 observed:: NamedTuple = NamedTuple {()} (()), ctx = nothing , skip_checks = false )
@@ -68,61 +63,54 @@ function ImperativeAffect(; f, kwargs...)
6863end
6964
7065function Base. show (io:: IO , mfa:: ImperativeAffect )
71- obs_vals = join ( map ((ob, nm) -> " $ob => $nm " , mfa. obs, mfa . obs_syms), " , " )
72- mod_vals = join ( map ((md, nm) -> " $md => $nm " , mfa. modified, mfa . mod_syms), " , " )
66+ obs = mfa. observed
67+ mod = mfa. modified
7368 affect = mfa. f
7469 print (io,
75- " ImperativeAffect(observed: [$obs_vals ], modified: [$mod_vals ], affect:$affect )" )
70+ " ImperativeAffect(observed: [$(obs) ], modified: [$(mod) ], affect:$affect )" )
7671end
7772func (f:: ImperativeAffect ) = f. f
7873context (a:: ImperativeAffect ) = a. ctx
79- observed (a:: ImperativeAffect ) = a. obs
80- observed_syms (a:: ImperativeAffect ) = a. obs_syms
8174function discretes (a:: ImperativeAffect )
8275 Iterators. filter (ModelingToolkit. isparameter,
8376 Iterators. flatten (Iterators. map (
8477 x -> symbolic_type (x) == NotSymbolic () && x isa AbstractArray ? x : [x],
8578 a. modified)))
8679end
87- modified (a:: ImperativeAffect ) = a. modified
88- modified_syms (a:: ImperativeAffect ) = a. mod_syms
8980
9081function Base.:(== )(a1:: ImperativeAffect , a2:: ImperativeAffect )
91- isequal (a1. f, a2. f) && isequal (a1. obs , a2. obs) && isequal (a1 . modified, a2 . modified ) &&
92- isequal (a1. obs_syms , a2. obs_syms) && isequal (a1 . mod_syms, a2 . mod_syms ) &&
82+ isequal (a1. f, a2. f) && isequal (a1. observed , a2. observed ) &&
83+ isequal (a1. modified , a2. modified ) &&
9384 isequal (a1. ctx, a2. ctx)
9485end
9586
9687function Base. hash (a:: ImperativeAffect , s:: UInt )
9788 s = hash (a. f, s)
98- s = hash (a. obs, s)
99- s = hash (a. obs_syms, s)
89+ s = hash (a. observed, s)
10090 s = hash (a. modified, s)
101- s = hash (a. mod_syms, s)
10291 hash (a. ctx, s)
10392end
10493
10594namespace_affects (af:: ImperativeAffect , s) = namespace_affect (af, s)
106- function namespace_affect (affect :: ImperativeAffect , s)
107- rmn = []
108- for modded in modified (affect )
109- if symbolic_type (modded) == NotSymbolic () && modded isa AbstractArray
110- res = []
111- for m in modded
112- push! (res, renamespace (s, m))
113- end
114- push! (rmn, res )
95+
96+ function _namespace_nt (nt :: NamedTuple , s :: AbstractSystem )
97+ return NamedTuple {keys(nt)} ( _namespace_nt ( values (nt), s) )
98+ end
99+
100+ function _namespace_nt (nt :: Union{AbstractArray, Tuple} , s :: AbstractSystem )
101+ return map (nt) do v
102+ if symbolic_type (v) == NotSymbolic ()
103+ _namespace_nt (v, s )
115104 else
116- push! (rmn, renamespace (s, modded) )
105+ renamespace (s, v )
117106 end
118107 end
119- ImperativeAffect (func (affect),
120- namespace_expr .(observed (affect), (s,)),
121- observed_syms (affect),
122- rmn,
123- modified_syms (affect),
124- context (affect),
125- affect. skip_checks)
108+ end
109+
110+ function namespace_affect (affect:: ImperativeAffect , s)
111+ obs = _namespace_nt (affect. observed, s)
112+ mod = _namespace_nt (affect. modified, s)
113+ ImperativeAffect (affect. f, obs, mod, affect. ctx, affect. skip_checks)
126114end
127115
128116function invalid_variables (sys, expr)
@@ -139,21 +127,6 @@ function unassignable_variables(sys, expr)
139127 x -> ! any (isequal (x), assignable_syms), written)
140128end
141129
142- @generated function _generated_writeback (integ, setters:: NamedTuple{NS1, <:Tuple} ,
143- values:: NamedTuple{NS2, <:Tuple} ) where {NS1, NS2}
144- setter_exprs = []
145- for name in NS2
146- if ! (name in NS1)
147- missing_name = " Tried to write back to $name from affect; only declared states ($NS1 ) may be written to."
148- error (missing_name)
149- end
150- push! (setter_exprs, :(setters.$ name (integ, values.$ name)))
151- end
152- return :(begin
153- $ (setter_exprs... )
154- end )
155- end
156-
157130function check_assignable (sys, sym)
158131 if symbolic_type (sym) == ScalarSymbolic ()
159132 is_variable (sys, sym) || is_parameter (sys, sym)
@@ -167,6 +140,42 @@ function check_assignable(sys, sym)
167140 end
168141end
169142
143+ function _nt_check_valid (nt:: NamedTuple , s:: AbstractSystem , isobserved:: Bool )
144+ _nt_check_valid (values (nt), s, isobserved)
145+ end
146+
147+ function _nt_check_valid (
148+ nt:: Union{Tuple, AbstractArray} , s:: AbstractSystem , isobserved:: Bool )
149+ for v in nt
150+ if symbolic_type (v) == NotSymbolic ()
151+ _nt_check_valid (v, s, isobserved)
152+ continue
153+ end
154+ if ! isobserved && ! check_assignable (s, v)
155+ error ("""
156+ Expression $v cannot be assigned to; currently only unknowns and parameters may \
157+ be updated by an affect.
158+ """ )
159+ end
160+ invalid = invalid_variables (s, v)
161+ isempty (invalid) && continue
162+ name = isobserved ? " Observed" : " Modified"
163+ error ("""
164+ $name expression $(v) in affect refers to missing variable(s) $(invalid) ; \
165+ the variables may not have been added (e.g. if a component is missing).
166+ """ )
167+ end
168+ end
169+
170+ function _nt_check_overlap (nta:: NamedTuple , ntb:: NamedTuple )
171+ common = intersect (keys (nta), keys (ntb))
172+ isempty (common) && return
173+ @warn """
174+ The symbols $common are declared as both observed and modified; this is a code smell \
175+ because it becomes easy to confuse them and assign/not assign a value.
176+ """
177+ end
178+
170179function compile_functional_affect (
171180 affect:: ImperativeAffect , sys; reset_jumps = false , kwargs... )
172181 #=
@@ -176,93 +185,27 @@ function compile_functional_affect(
176185 call the affect method
177186 unpack and apply the resulting values
178187 =#
179- function check_dups (syms, exprs) # = (syms_dedup, exprs_dedup)
180- seen = Set {Symbol} ()
181- syms_dedup = []
182- exprs_dedup = []
183- for (sym, exp) in Iterators. zip (syms, exprs)
184- if ! in (sym, seen)
185- push! (syms_dedup, sym)
186- push! (exprs_dedup, exp)
187- push! (seen, sym)
188- elseif ! affect. skip_checks
189- @warn " Expression $(expr) is aliased as $sym , which has already been used. The first definition will be used."
190- end
191- end
192- return (syms_dedup, exprs_dedup)
193- end
194188
195- dvs = unknowns (sys)
196- ps = parameters (sys)
197-
198- obs_exprs = observed (affect)
199- if ! affect. skip_checks
200- for oexpr in obs_exprs
201- invalid_vars = invalid_variables (sys, oexpr)
202- if length (invalid_vars) > 0
203- error (" Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing)." )
204- end
205- end
206- end
207- obs_syms = observed_syms (affect)
208- obs_syms, obs_exprs = check_dups (obs_syms, obs_exprs)
209-
210- mod_exprs = modified (affect)
211189 if ! affect. skip_checks
212- for mexpr in mod_exprs
213- if ! check_assignable (sys, mexpr)
214- @warn (" Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect." )
215- end
216- invalid_vars = unassignable_variables (sys, mexpr)
217- if length (invalid_vars) > 0
218- error (" Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars) ; the variables may not have been added (e.g. if a component is missing) or they may have been reduced away." )
219- end
220- end
221- end
222- mod_syms = modified_syms (affect)
223- mod_syms, mod_exprs = check_dups (mod_syms, mod_exprs)
224-
225- overlapping_syms = intersect (mod_syms, obs_syms)
226- if length (overlapping_syms) > 0 && ! affect. skip_checks
227- @warn " The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value."
190+ _nt_check_valid (affect. observed, sys, true )
191+ _nt_check_valid (affect. modified, sys, false )
192+ _nt_check_overlap (affect. observed, affect. modified)
228193 end
229194
230195 # sanity checks done! now build the data and update function for observed values
231- mkzero (sz) =
232- if sz === ()
233- 0.0
234- else
235- zeros (sz)
236- end
237- obs_fun = build_explicit_observed_function (
238- sys, Symbolics. scalarize .(obs_exprs);
239- mkarray = (es, _) -> MakeTuple (es))
240- obs_sym_tuple = (obs_syms... ,)
241-
242- # okay so now to generate the stuff to assign it back into the system
243- mod_pairs = mod_exprs .=> mod_syms
244- mod_names = (mod_syms... ,)
245- mod_og_val_fun = build_explicit_observed_function (
246- sys, Symbolics. scalarize .(first .(mod_pairs));
247- mkarray = (es, _) -> MakeTuple (es))
196+ let user_affect = func (affect), ctx = context (affect),
197+ obs_getter = isempty (affect. observed) ? Returns ((;)) : getsym (sys, affect. observed),
198+ mod_getter = isempty (affect. modified) ? Returns ((;)) : getsym (sys, affect. modified),
199+ mod_setter = isempty (affect. modified) ? Returns ((;)) : setsym (sys, affect. modified),
200+ reset_jumps = reset_jumps
248201
249- upd_funs = NamedTuple {mod_names} ((setu .((sys,), first .(mod_pairs))... ,))
250-
251- let user_affect = func (affect), ctx = context (affect), reset_jumps = reset_jumps
252202 @inline function (integ)
253- # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
254- modvals = mod_og_val_fun (integ. u, integ. p, integ. t)
255- upd_component_array = NamedTuple {mod_names} (modvals)
256-
257- # update the observed values
258- obs_component_array = NamedTuple {obs_sym_tuple} (obs_fun (
259- integ. u, integ. p, integ. t))
203+ mod = mod_getter (integ)
204+ obs = obs_getter (integ)
260205
261206 # let the user do their thing
262- upd_vals = user_affect (upd_component_array, obs_component_array, ctx, integ)
263-
264- # write the new values back to the integrator
265- _generated_writeback (integ, upd_funs, upd_vals)
207+ upd_vals = user_affect (mod, obs, ctx, integ)
208+ mod_setter (integ, upd_vals)
266209
267210 reset_jumps && reset_aggregated_jumps! (integ)
268211 end
@@ -271,19 +214,22 @@ end
271214
272215scalarize_affects (affects:: ImperativeAffect ) = affects
273216
274- function vars! (vars, aff:: ImperativeAffect ; op = Differential)
275- for var in Iterators. flatten ((observed (aff), modified (aff)))
276- if symbolic_type (var) == NotSymbolic ()
277- if var isa AbstractArray
278- for v in var
279- v = unwrap (v)
280- vars! (vars, v)
281- end
282- end
283- else
284- var = unwrap (var)
285- vars! (vars, var)
217+ function _vars_nt! (vars, nt:: NamedTuple , op)
218+ _vars_nt! (vars, values (nt), op)
219+ end
220+
221+ function _vars_nt! (vars, nt:: Union{AbstractArray, Tuple} , op)
222+ for v in nt
223+ if symbolic_type (v) == NotSymbolic ()
224+ _vars_nt! (vars, v, op)
225+ continue
286226 end
227+ vars! (vars, v; op)
287228 end
229+ end
230+
231+ function vars! (vars, aff:: ImperativeAffect ; op = Differential)
232+ _vars_nt! (vars, aff. observed, op)
233+ _vars_nt! (vars, aff. modified, op)
288234 return vars
289235end
0 commit comments