Skip to content

Commit 8cf9ef5

Browse files
committed
Simplified iterator to solve a linear equation system (using a while-loop instead of a for-loop)
1 parent 494e71f commit 8cf9ef5

File tree

2 files changed

+100
-139
lines changed

2 files changed

+100
-139
lines changed

src/EquationAndStateInfo.jl

Lines changed: 89 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Define linear equation system "A*x=b" with `length(x) = sum(vTear_lengths)`.
6161
If `A_is_constant = true` then `A` is a matrix that is constant after
6262
initialization.
6363
64-
For details of its usage for code generation see [`LinearEquationsIterator`](@ref).
64+
For details of its usage for code generation see [`LinearEquationsIteration`](@ref).
6565
"""
6666
mutable struct LinearEquations{FloatType <: Real}
6767
A_is_constant::Bool # = true, if A-matrix is constant
@@ -107,22 +107,19 @@ end
107107
LinearEquations(args...) = LinearEquations{Float64}(args...)
108108

109109

110-
111110
"""
112-
leqIterator = LinearEquationsIterator{FloatType}(leq::LinearEquations)
111+
iterating = LinearEquationsIteration(leq::LinearEquations{FloatType}, isInitial::Bool, time)
113112
114-
Return instance of a struct to iterate over LinearEquations with Base.iterate(leq::LinearEquationsIterator)
115-
with the following for-loop:
113+
This function solves a linear equation system in residual form "residual = A*x - b"
114+
by iterating with a while loop over this system:
116115
117116
```julia
118117
function getDerivatives!(_der_x, _x, _m, _time)::Nothing
119118
_leq::Union{Nothing,LinearEquations{FloatType}} = nothing
120119
...
121-
_leq = _m.linearEquations[<nr>] # leq::LinearEquations{FloatType}
122-
for _leq_evaluate = LinearEquationsIterator(_leq, _m.isInitial, _m.time)
123-
if !_leq_evaluate
124-
break
125-
end
120+
_leq = _m.linearEquations[<nr>] # leq::LinearEquations{FloatType}
121+
_leq.mode = -2 # initializes the iteration
122+
while LinearEquationsIteration(_leq, _m.isInitial, _m.time)
126123
v_tear1 = _leq.vTear_value[1:3]
127124
v_tear2 = _leq.vTear_value[4]
128125
...
@@ -135,31 +132,14 @@ function getDerivatives!(_der_x, _x, _m, _time)::Nothing
135132
end
136133
```
137134
"""
138-
struct LinearEquationsIterator{FloatType <: Real}
139-
leq::LinearEquations{FloatType}
140-
isInitial::Bool
141-
time
142-
end
143-
144-
const ncalls = [0]
145-
146-
function Base.iterate(iterator::LinearEquationsIterator{FloatType}, mode::Int = -3) where {FloatType <: Real}
135+
function LinearEquationsIteration(leq::LinearEquations{FloatType}, isInitial::Bool, time)::Bool where {FloatType}
147136
#=
148-
for item in iterator
149-
# body
150-
end
151-
152-
is translated to
153-
154-
next = iterate(iterator)
155-
while next != nothing
156-
(item, mode) = next
157-
# body
158-
next = iterate(iterator, mode)
159-
end
137+
while LinearEquationsIteration(leq, isInitial,time)
138+
# body of while-loop
139+
end
160140
161141
162-
iterate(iterator,mode=0) to solve "residuals = A*x - b" where
142+
Solve "residuals = A*x - b" where
163143
A and b can be functions of positive(c_i'*x - d_i). In this case this system
164144
is solved by a fixed point iteration scheme, that is "A*x = b" is solved
165145
until potentially present positive(c_i'*x - d_i) calls are consistent to x.
@@ -168,81 +148,71 @@ function Base.iterate(iterator::LinearEquationsIterator{FloatType}, mode::Int =
168148
(it might be that simulation is still successful, even if "x" and
169149
positive(..) are temporarily not consistent to each other).
170150
171-
Note, the current values of A,x,b,residuals are stored in iterator.leq.
151+
Note, the current values of A,x,b,residuals are stored in leq.
172152
If A is fixed after initialization (leq.A_is_constant = true), then A is computed
173153
only once at initialization, the LU decomposition of A is stored in leq
174154
and used in subsequent calls to solve the equation system.
175155
176156
177-
On input (leq = iterator.leq; x = leq.vTear_value):
178-
179-
mode = -3: # Initialize fixed point iteration
157+
On input (mode = leq.mode, x = leq.vTear_value):
158+
159+
mode = -2: # Initialize fixed point iteration
180160
leq.niter = 0
181-
leq.success = false
182-
goto mode = 0
161+
leq.success = false
162+
leq.mode = 0
163+
x .= 0
164+
return true
183165
184-
mode = 0: # Start next (fixed point) iteration
166+
mode = -1: if leq.success
167+
return false # Terminate while-loop
168+
elseif leq.niter > 20
169+
<warning>
170+
return false # Terminate while-loop
171+
end
172+
leq.mode = 0 # Re-initialize next solver-iteration
185173
x .= 0
186-
leq.mode = 0
187-
return (true, 1) # next mode = 1 (compute "b" after the next iteration)
188-
189-
mode = 1: b = -residuals # Compute "b".
190-
x[1] = 1.0
191-
leq.mode = 1
192-
return (true, 2) # next mode = 2 (compute A[:,1] after the next iteration)
174+
return true
175+
176+
mode = 0: b = -residuals # Compute b
177+
mode = 1 # Compute A[:,1] after the next iteration
178+
x[1] = 1.0
179+
return true
193180
194-
mode > 1: j = mode - 1
195-
A[:,j] = residuals + b # Compute A[:,j].
196-
x[j] = 0.0
197-
if j < length(x)
198-
x[j+1] = 1.0
199-
leq.mode = mode
200-
return (true, mode+1) # next mode = mode + 1 (compute A[:,j+1] after the next iteration)
201-
elseif j == length(x)
202-
x = A\b # Solve linear equation system A*x = b for x.
203-
leq.mode = -1 # Compute all variables IN the next iteration as function of x
204-
leq.niter += 1 # Increment number of iterations to solve A*x = b
205-
leq.success = true # Terminate for-loop at the beginning of the next iteration,
206-
# provided no positive(..) call changes its value.
207-
# (leq.success is set to false in positive(..), if the return value changes).
208-
return (true, -2) # next mode = -2 (either terminate for-loop or re-initialize the next solver-iteration)
209-
end
210-
211-
mode = -2: if leq.success
212-
return (false, 0) # Terminate for-loop
213-
elseif leq.niter > 20
214-
<warning>
215-
return (false, 0) # Terminate for-loop
216-
end
217-
goto mode = 0 # Re-initialize next solver-iteration
181+
mode > 1: j = mode
182+
A[:,j] = residuals + b # Compute A[:,j].
183+
x[j] = 0.0
184+
if j < nx
185+
x[j+1] = 1.0
186+
leq.mode = j+1
187+
return true # Compute A[:,j+1] after the next iteration
188+
elseif j == nx
189+
x = A\b # Solve linear equation system A*x = b for x.
190+
leq.mode = -1 # Compute all variables IN the next iteration as function of x
191+
leq.niter += 1 # Increment number of iterations to solve A*x = b
192+
leq.success = true # Terminate for-loop at the beginning of the next iteration,
193+
# provided no positive(..) call changes its value.
194+
# (leq.success is set to false in positive(..), if the return value changes).
195+
return true
196+
end
218197
=#
198+
mode = leq.mode
199+
x = leq.vTear_value
200+
nx = length(x)
219201

220-
#=
221-
println("... 1, mode = $mode, time = $(iterator.time)")
222-
if iterator.isInitial
223-
ncalls[1] = 0
224-
else
225-
ncalls[1] += 1
226-
if ncalls[1] > 20000
227-
error("... too many calls")
228-
end
229-
end
230-
=#
231-
232-
233-
leq = iterator.leq
234-
if mode == -3
202+
if mode == -2
235203
# Initialize fixed point iteration
236204
leq.niter = 0
237205
leq.success = false
238206
empty!(leq.inconsistentPositive)
239207
empty!(leq.inconsistentNegative)
240-
mode = 0
208+
leq.mode = 0
209+
x .= 0
210+
return true
241211

242-
elseif mode == -2
212+
elseif mode == -1
243213
# Either terminate fixed point iteration, or start a new iteration
244214
if leq.success
245-
return (false,0)
215+
return false
246216
elseif leq.niter > niter_max
247217
str = ""
248218
if length(leq.inconsistentPositive) > 0
@@ -254,35 +224,27 @@ function Base.iterate(iterator::LinearEquationsIterator{FloatType}, mode::Int =
254224
end
255225
str = str * "negative(expr) is inconsistent for expr = $(leq.inconsistentNegative)."
256226
end
257-
@warn "At time = $(iterator.time), no consistent solution found for mixed linear equation system.\n" *
227+
@warn "At time = $time, no consistent solution found for mixed linear equation system.\n" *
258228
"Simulation is continued although some variables might not be correct at this time instant.\n$str"
259-
return (false,0)
229+
return false
260230
end
261-
mode = 0
262-
end
263-
264-
x = leq.vTear_value
265-
if mode == 0
266-
# Re-initialize iteration variables and compute b-vector after next iteration
267-
x .= 0
268231
leq.mode = 0
269-
return (true,1)
232+
x .= 0
233+
return true
234+
235+
elseif mode < -2 || mode > nx
236+
@goto ERROR
270237
end
271238

272-
nx = length(x)
273239
A = leq.A
274240
b = leq.b
275241
nResiduals = leq.nResiduals
276242
residuals = leq.residuals
277243
residual_value = leq.residual_value
278244
residual_unitRanges = leq.residual_unitRanges
279245
residual_indices = leq.residual_indices
280-
281-
if mode < 1 || mode > nx+1
282-
@goto ERROR
283-
end
284246

285-
if iterator.isInitial && mode == 1
247+
if isInitial && mode == 0
286248
# Construct unit ranges for the residual variables vector to copy values into the residuals vector
287249
j = 1
288250
for i = 1:nResiduals
@@ -319,30 +281,32 @@ function Base.iterate(iterator::LinearEquationsIterator{FloatType}, mode::Int =
319281
end
320282
end
321283

322-
if !leq.A_is_constant || iterator.isInitial # A is not constant or A is constant and isInitial = true
323-
if mode == 1
324-
# Terminating code for mode = 0 (residuals = A*x - b -> b = -residuals)
284+
if !leq.A_is_constant || isInitial # A is not constant or A is constant and isInitial = true
285+
if mode == 0
286+
# residuals = A*x - b -> b = -residuals)
325287
for i = 1:nx
326288
b[i] = -residuals[i]
327289
end
328-
else
329-
# Terminating code for mode = 1..nx (residuals = A*x - b -> A[:,j] = residuals + b)
330-
j = mode-1
331-
for i = 1:nx
332-
A[i,j] = residuals[i] + b[i]
333-
end
334-
x[j] = 0
290+
leq.mode = 1
291+
x[1] = convert(FloatType, 1)
292+
return true
335293
end
294+
295+
# residuals = A*x - b -> A[:,j] = residuals + b)
296+
j = mode
297+
for i = 1:nx
298+
A[i,j] = residuals[i] + b[i]
299+
end
300+
x[j] = 0
336301

337-
if mode <= nx
338-
# Start code for mode = 1..nx
339-
x[mode] = 1
340-
leq.mode = mode
341-
return (true, mode+1)
302+
if j < nx
303+
leq.mode += 1
304+
x[leq.mode] = convert(FloatType, 1)
305+
return true
342306
end
343307

344-
# mode == nx+1; Solve linear equation system
345-
if length(x) == 1
308+
# Solve linear equation system
309+
if nx == 1
346310
x[1] = b[1]/A[1,1]
347311
if !isfinite(x[1])
348312
error("Linear scalar equation system is singular resulting in: ", leq.vTear_names[1], " = ", x[1])
@@ -353,13 +317,13 @@ function Base.iterate(iterator::LinearEquationsIterator{FloatType}, mode::Int =
353317
ldiv!(leq.luA, x)
354318
end
355319

356-
elseif leq.A_is_constant && !iterator.isInitial && mode == 1 # isInitial=false, LU decomposition of A is available in leq.luA
320+
elseif leq.A_is_constant && !isInitial # isInitial=false, LU decomposition of A is available in leq.luA
357321
for i = 1:nx
358322
x[i] = -residuals[i]
359323
end
360324

361325
# Solve linear equation system
362-
if length(x) == 1
326+
if nx == 1
363327
x[1] = x[1]/A[1,1]
364328
if !isfinite(x[1])
365329
error("Linear scalar equation system is singular resulting in: ", leq.vTear_names[1], " = ", x[1])
@@ -378,15 +342,14 @@ function Base.iterate(iterator::LinearEquationsIterator{FloatType}, mode::Int =
378342
leq.success = true # Terminate for-loop at the beginning of the next iteration,
379343
# provided no positive(..) call changes its value.
380344
# (leq.success is set to false in positive(..), if the return value changes).
381-
return (true, -2) # next mode = -2 (either terminate for-loop or re-initialize the next solver-iteration)
345+
return true
382346

383347
@label ERROR
384-
@error "Should not occur (Bug in file ModiaBase/src/EquationAndStateInfo.jl,\nmode = $mode; vTear = $(leq.vTear_names))."
348+
@error "Should not occur (Bug in file ModiaBase/src/EquationAndStateInfo.jl,\nleq.mode = $(leq.mode); vTear = $(leq.vTear_names), time=$time, isInitial=$isInitial)."
385349
end
386350

387351

388352

389-
390353
"""
391354
@enum EquationInfoStatus
392355

0 commit comments

Comments
 (0)