Skip to content

Commit a058736

Browse files
authored
Merge pull request #330 from hitonanode/kalman-filter
Multivariate Gaussian / Kalman filter
2 parents ee6082f + ce0b7dd commit a058736

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#ifndef MULTIVARIATE_GAUSSIAN_HPP
2+
#define MULTIVARIATE_GAUSSIAN_HPP
3+
4+
#include <cassert>
5+
#include <vector>
6+
7+
// #include "linear_algebra_matrix/matrix.hpp"
8+
9+
// Multivariate Gausssian distribution / Kalman filter
10+
// 多変量正規分布の数値計算・カルマンフィルタ
11+
template <class Matrix> struct MultivariateGaussian {
12+
13+
// 正規分布 N(x, P)
14+
std::vector<double> x; // 期待値
15+
Matrix P; // 分散共分散行列
16+
17+
void set(const std::vector<double> &x0, const Matrix &P0) {
18+
const int dim = x0.size();
19+
assert(P0.height() == dim and P0.width() == dim);
20+
21+
x = x0;
22+
P = P0;
23+
}
24+
25+
// 加算
26+
// すなわち x <- x + dx
27+
void shift(const std::vector<double> &dx) {
28+
const int n = x.size();
29+
assert(dx.size() == n);
30+
31+
for (int i = 0; i < n; ++i) x.at(i) += dx.at(i);
32+
}
33+
34+
// F: 状態遷移行列 正方行列を想定
35+
// すなわち x <- Fx
36+
void linear_transform(const Matrix &F) {
37+
x = F * x;
38+
P = F * P * F.transpose();
39+
}
40+
41+
// Q: ゼロ平均ガウシアンノイズの分散共分散行列
42+
// すなわち x <- x + w, w ~ N(0, Q)
43+
void add_noise(const Matrix &Q) { P = P + Q; }
44+
45+
// 現在の x の分布を P(x | *) として、条件付き確率 P(x | *, z) で更新する
46+
// H: 観測行列, R: 観測ノイズの分散共分散行列, z: 観測値
47+
// すなわち z = Hx + v, v ~ N(0, R)
48+
void measure(const Matrix &H, const Matrix &R, const std::vector<double> &z,
49+
double regularlize = 1e-9) {
50+
const int nobs = z.size();
51+
52+
// 残差 e = z - Hx
53+
const auto Hx = H * x;
54+
std::vector<double> e(nobs);
55+
for (int i = 0; i < nobs; ++i) e.at(i) = z.at(i) - Hx.at(i);
56+
57+
// 残差共分散 S = R + H P H^T
58+
Matrix Sinv = R + H * P * H.transpose();
59+
Sinv = Sinv + Matrix::Identity(nobs) * regularlize; // 不安定かも?
60+
Sinv.inverse();
61+
62+
// カルマンゲイン K = P H^T S^-1
63+
Matrix K = P * H.transpose() * Sinv;
64+
65+
// Update x
66+
const auto dx = K * e;
67+
for (int i = 0; i < (int)x.size(); ++i) x.at(i) += dx.at(i);
68+
69+
P = P - K * H * P;
70+
}
71+
};
72+
73+
#endif

heuristic/multivariate_gaussian.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
---
2+
title: Multivariate Gaussian Distribution, Kalman filter / 多変量正規分布・カルマンフィルタ
3+
documentation_of: ./multivariate_gaussian.hpp
4+
---
5+
6+
多変量正規分布のパラメータを管理するクラス.線形変換・ノイズの加算・観測による事後確率の更新が行える.カルマンフィルタの実装に利用可能.
7+
8+
## 使用方法
9+
10+
線形システムのカルマンフィルタの実装例を以下に示す.
11+
12+
```cpp
13+
#include "linear_algebra_matrix/matrix.hpp"
14+
15+
// 初期化
16+
MultivariateGaussian<matrix<double>> kf;
17+
vector<double> mu(dim);
18+
matrix<double> Sigma(dim, dim);
19+
kf.set(mu, Sigma); // N(mu, Sigma) で初期化
20+
21+
// 以下の「時間発展」「雑音の付与」「制御信号の注入」「推定」を任意の順序で任意の回数行ってよい。
22+
23+
// 時間発展
24+
matrix<double> F(dim, dim); // 時間発展行列
25+
kf.linear_transform(F);
26+
27+
// 雑音の付与
28+
matrix<double> Q(dim, dim); // 正規雑音の分散・共分散行列
29+
kf.add_noise(Q);
30+
31+
// 制御信号の注入
32+
vector<double> u(dim); // 制御入力
33+
kf.shift(u);
34+
35+
// 観測
36+
matrix<double> H(o, dim); // 観測行列
37+
matrix<double> R(o, o); // 観測に重畳される正規雑音の分散・共分散行列
38+
vector<double> z(o); // 観測行列による観測結果
39+
double regularize = 1e-9; // 逆行列数値計算の安定のためのパラメータ
40+
kf.measure(H, R, z, regularize);
41+
42+
// 推定
43+
vector<double> est = kf.x;
44+
```
45+
46+
- 現在の MAP 推定量が知りたい -> mu を見ればよい
47+
- 周辺分布が欲しい -> mu と Sigma のうち特定の次元だけ取り出せばよい
48+
- 一部の次元だけ観測できた場合の残りの次元の条件付き分布が欲しい → 未実装です
49+
- サンプリングしたい → 未実装です
50+
51+
## 問題例
52+
53+
- [第一回マスターズ選手権 -決勝- A - Windy Drone Control (A)](https://atcoder.jp/contests/masters2024-final/tasks/masters2024_final_a)

0 commit comments

Comments
 (0)