Skip to content

Commit c8ea8b3

Browse files
committed
reduce allocation KalmanFilter and ExtendedKalmanFilter
1 parent 9704c9b commit c8ea8b3

File tree

2 files changed

+32
-23
lines changed

2 files changed

+32
-23
lines changed

src/estimator/kalman.jl

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,6 @@ struct KalmanFilter{NT<:Real, SM<:LinModel} <: StateEstimator{NT}
271271
::Hermitian{NT, Matrix{NT}}
272272
::Hermitian{NT, Matrix{NT}}
273273
::Matrix{NT}
274-
::Matrix{NT}
275274
direct::Bool
276275
corrected::Vector{Bool}
277276
buffer::StateEstimatorBuffer{NT}
@@ -291,7 +290,6 @@ struct KalmanFilter{NT<:Real, SM<:LinModel} <: StateEstimator{NT}
291290
P̂_0 = Hermitian(P̂_0, :L)
292291
= copy(P̂_0)
293292
= zeros(NT, nx̂, nym)
294-
= Hermitian(zeros(NT, nym, nym), :L)
295293
corrected = [false]
296294
buffer = StateEstimatorBuffer{NT}(nu, nx̂, nym, ny, nd)
297295
return new{NT, SM}(
@@ -301,7 +299,7 @@ struct KalmanFilter{NT<:Real, SM<:LinModel} <: StateEstimator{NT}
301299
As, Cs_u, Cs_y, nint_u, nint_ym,
302300
Â, B̂u, Ĉ, B̂d, D̂d,
303301
P̂_0, Q̂, R̂,
304-
K̂, M̂,
302+
K̂,
305303
direct, corrected,
306304
buffer
307305
)
@@ -783,7 +781,6 @@ struct ExtendedKalmanFilter{NT<:Real, SM<:SimModel} <: StateEstimator{NT}
783781
::Hermitian{NT, Matrix{NT}}
784782
::Hermitian{NT, Matrix{NT}}
785783
::Matrix{NT}
786-
::Matrix{NT}
787784
F̂_û::Matrix{NT}
788785
::Matrix{NT}
789786
direct::Bool
@@ -806,7 +803,6 @@ struct ExtendedKalmanFilter{NT<:Real, SM<:SimModel} <: StateEstimator{NT}
806803
= Hermitian(R̂, :L)
807804
= copy(P̂_0)
808805
= zeros(NT, nx̂, nym)
809-
= Hermitian(zeros(NT, nym, nym), :L)
810806
F̂_û, Ĥ = zeros(NT, nx̂+nu, nx̂), zeros(NT, ny, nx̂)
811807
corrected = [false]
812808
buffer = StateEstimatorBuffer{NT}(nu, nx̂, nym, ny, nd)
@@ -817,7 +813,7 @@ struct ExtendedKalmanFilter{NT<:Real, SM<:SimModel} <: StateEstimator{NT}
817813
As, Cs_u, Cs_y, nint_u, nint_ym,
818814
Â, B̂u, Ĉ, B̂d, D̂d,
819815
P̂_0, Q̂, R̂,
820-
K̂, M̂,
816+
K̂,
821817
F̂_û, Ĥ,
822818
direct, corrected,
823819
buffer
@@ -991,40 +987,44 @@ function validate_kfcov(nym, nx̂, Q̂, R̂, P̂_0=nothing)
991987
end
992988

993989
"""
994-
correct_estimate_kf!(estim::StateEstimator, y0m, d0, Ĉm)
990+
correct_estimate_kf!(estim::Union{KalmanFilter, ExtendedKalmanFilter}, y0m, d0, Ĉm)
995991
996992
Correct time-varying/extended Kalman Filter estimates with augmented `Ĉm` matrices.
997993
998994
Allows code reuse for [`KalmanFilter`](@ref), [`ExtendedKalmanFilterKalmanFilter`](@ref).
999995
See [`update_estimate_kf!`](@ref) for more information.
1000996
"""
1001-
function correct_estimate_kf!(estim::StateEstimator, y0m, d0, Ĉm)
1002-
R̂, M̂, = estim., estim., estim.
997+
function correct_estimate_kf!(estim::Union{KalmanFilter, ExtendedKalmanFilter}, y0m, d0, Ĉm)
998+
R̂, K̂ = estim.R̂, estim.
1003999
x̂0, P̂ = estim.x̂0, estim.
1000+
# in-place operations to reduce allocations:
10041001
P̂_Ĉmᵀ =
10051002
mul!(P̂_Ĉmᵀ, P̂.data, Ĉm') # the ".data" weirdly removes a type instability in mul!
1006-
mul!(M̂, Ĉm, P̂_Ĉmᵀ)
1007-
.+=
1003+
= estim.buffer.
1004+
mul!(Ŝ, Ĉm, P̂_Ĉmᵀ)
1005+
.+=
10081006
= P̂_Ĉmᵀ
1009-
M̂_chol = cholesky!(Hermitian()) # also modifies
1007+
M̂_chol = cholesky!(Hermitian()) # also modifies
10101008
rdiv!(K̂, M̂_chol)
10111009
ŷ0 = estim.buffer.
10121010
ĥ!(ŷ0, estim, estim.model, x̂0, d0)
10131011
ŷ0m = @views ŷ0[estim.i_ym]
10141012
= ŷ0m
10151013
v̂ .= y0m .- ŷ0m
1016-
x̂0corr, P̂corr = estim.x̂0, estim.
1017-
mul!(x̂0corr, K̂, v̂, 1, 1)
1018-
I_minus_K̂_Ĉm = estim.buffer.
1014+
x̂0corr = x̂0
1015+
mul!(x̂0corr, K̂, v̂, 1, 1) # also modifies estim.x̂0
1016+
I_minus_K̂_Ĉm = estim.buffer.
10191017
mul!(I_minus_K̂_Ĉm, K̂, Ĉm)
10201018
lmul!(-1, I_minus_K̂_Ĉm)
1021-
I_minus_K̂_Ĉm[diagind(I_minus_K̂_Ĉm)] .+= 1 # compute I - K̂*Ĉm
1022-
P̂corr .= Hermitian(I_minus_K̂_Ĉm * P̂) # TODO: remove this allocation
1019+
I_minus_K̂_Ĉm[diagind(I_minus_K̂_Ĉm)] .+= 1 # compute I - K̂*Ĉm in-place
1020+
P̂corr = estim.buffer.
1021+
mul!(P̂corr, I_minus_K̂_Ĉm, P̂)
1022+
estim.P̂ .= Hermitian(P̂corr, :L)
10231023
return nothing
10241024
end
10251025

10261026
"""
1027-
update_estimate_kf!(estim::StateEstimator, y0m, d0, u0, Ĉm, Â)
1027+
update_estimate_kf!(estim::Union{KalmanFilter, ExtendedKalmanFilter}, y0m, d0, u0, Ĉm, Â)
10281028
10291029
Update time-varying/extended Kalman Filter estimates with augmented `Ĉm` and `Â` matrices.
10301030
@@ -1034,18 +1034,23 @@ substitutes the augmented model matrices with its Jacobians (`Â = F̂` and `C
10341034
The implementation uses in-place operations and explicit factorization to reduce
10351035
allocations. See e.g. [`KalmanFilter`](@ref) docstring for the equations.
10361036
"""
1037-
function update_estimate_kf!(estim::StateEstimator, y0m, d0, u0, Ĉm, Â)
1037+
function update_estimate_kf!(estim::Union{KalmanFilter, ExtendedKalmanFilter}, y0m, d0, u0, Ĉm, Â)
10381038
if !estim.direct
10391039
correct_estimate_kf!(estim, y0m, d0, Ĉm)
10401040
end
10411041
x̂0corr, P̂corr = estim.x̂0, estim.
10421042
= estim.
10431043
x̂0next, û0 = estim.buffer.x̂, estim.buffer.
1044+
# in-place operations to reduce allocations:
10441045
f̂!(x̂0next, û0, estim, estim.model, x̂0corr, u0, d0)
1045-
# TODO: use buffer.P̂ to reduce allocations
1046-
P̂next = Hermitian(Â * P̂corr *' .+ Q̂, :L)
1046+
P̂corr_Âᵀ = estim.buffer.
1047+
mul!(P̂corr_Âᵀ, P̂corr, Â')
1048+
Â_P̂corr_Âᵀ = estim.buffer.
1049+
mul!(Â_P̂corr_Âᵀ, Â, P̂corr_Âᵀ)
1050+
P̂next = estim.buffer.
1051+
P̂next .= Â_P̂corr_Âᵀ .+
10471052
x̂0next .+= estim.f̂op .- estim.x̂op
10481053
estim.x̂0 .= x̂0next
1049-
estim.P̂ .= P̂next
1054+
estim.P̂ .= Hermitian(P̂next, :L)
10501055
return nothing
10511056
end

src/state_estim.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ struct StateEstimatorBuffer{NT<:Real}
2323
::Vector{NT}
2424
::Vector{NT}
2525
::Matrix{NT}
26+
::Matrix{NT}
27+
::Matrix{NT}
2628
ym::Vector{NT}
2729
::Vector{NT}
2830
d ::Vector{NT}
@@ -43,11 +45,13 @@ function StateEstimatorBuffer{NT}(
4345
= Vector{NT}(undef, nu)
4446
= Vector{NT}(undef, nx̂)
4547
= Matrix{NT}(undef, nx̂, nx̂)
48+
= Matrix{NT}(undef, nx̂, nx̂)
49+
= Matrix{NT}(undef, nym, nym)
4650
ym = Vector{NT}(undef, nym)
4751
= Vector{NT}(undef, ny)
4852
d = Vector{NT}(undef, nd)
4953
empty = Vector{NT}(undef, 0)
50-
return StateEstimatorBuffer{NT}(u, û, x̂, P̂, ym, ŷ, d, empty)
54+
return StateEstimatorBuffer{NT}(u, û, x̂, P̂, Q̂, R̂, ym, ŷ, d, empty)
5155
end
5256

5357
const IntVectorOrInt = Union{Int, Vector{Int}}

0 commit comments

Comments
 (0)