|
| 1 | +#ifndef MULTIVARIATE_GAUSSIAN_HPP |
| 2 | +#define MULTIVARIATE_GAUSSIAN_HPP |
| 3 | + |
| 4 | +#include <cassert> |
| 5 | +#include <vector> |
| 6 | + |
| 7 | +// #include "linear_algebra_matrix/matrix.hpp" |
| 8 | + |
| 9 | +// Multivariate Gausssian distribution / Kalman filter |
| 10 | +// 多変量正規分布の数値計算・カルマンフィルタ |
| 11 | +template <class Matrix> struct MultivariateGaussian { |
| 12 | + |
| 13 | + // 正規分布 N(x, P) |
| 14 | + std::vector<double> x; // 期待値 |
| 15 | + Matrix P; // 分散共分散行列 |
| 16 | + |
| 17 | + void set(const std::vector<double> &x0, const Matrix &P0) { |
| 18 | + const int dim = x0.size(); |
| 19 | + assert(P0.height() == dim and P0.width() == dim); |
| 20 | + |
| 21 | + x = x0; |
| 22 | + P = P0; |
| 23 | + } |
| 24 | + |
| 25 | + // 加算 |
| 26 | + // すなわち x <- x + dx |
| 27 | + void shift(const std::vector<double> &dx) { |
| 28 | + const int n = x.size(); |
| 29 | + assert(dx.size() == n); |
| 30 | + |
| 31 | + for (int i = 0; i < n; ++i) x.at(i) += dx.at(i); |
| 32 | + } |
| 33 | + |
| 34 | + // F: 状態遷移行列 正方行列を想定 |
| 35 | + // すなわち x <- Fx |
| 36 | + void linear_transform(const Matrix &F) { |
| 37 | + x = F * x; |
| 38 | + P = F * P * F.transpose(); |
| 39 | + } |
| 40 | + |
| 41 | + // Q: ゼロ平均ガウシアンノイズの分散共分散行列 |
| 42 | + // すなわち x <- x + w, w ~ N(0, Q) |
| 43 | + void add_noise(const Matrix &Q) { P = P + Q; } |
| 44 | + |
| 45 | + // 現在の x の分布を P(x | *) として、条件付き確率 P(x | *, z) で更新する |
| 46 | + // H: 観測行列, R: 観測ノイズの分散共分散行列, z: 観測値 |
| 47 | + // すなわち z = Hx + v, v ~ N(0, R) |
| 48 | + void measure(const Matrix &H, const Matrix &R, const std::vector<double> &z, |
| 49 | + double regularlize = 1e-9) { |
| 50 | + const int nobs = z.size(); |
| 51 | + |
| 52 | + // 残差 e = z - Hx |
| 53 | + const auto Hx = H * x; |
| 54 | + std::vector<double> e(nobs); |
| 55 | + for (int i = 0; i < nobs; ++i) e.at(i) = z.at(i) - Hx.at(i); |
| 56 | + |
| 57 | + // 残差共分散 S = R + H P H^T |
| 58 | + Matrix Sinv = R + H * P * H.transpose(); |
| 59 | + Sinv = Sinv + Matrix::Identity(nobs) * regularlize; // 不安定かも? |
| 60 | + Sinv.inverse(); |
| 61 | + |
| 62 | + // カルマンゲイン K = P H^T S^-1 |
| 63 | + Matrix K = P * H.transpose() * Sinv; |
| 64 | + |
| 65 | + // Update x |
| 66 | + const auto dx = K * e; |
| 67 | + for (int i = 0; i < (int)x.size(); ++i) x.at(i) += dx.at(i); |
| 68 | + |
| 69 | + P = P - K * H * P; |
| 70 | + } |
| 71 | +}; |
| 72 | + |
| 73 | +#endif |
0 commit comments