Skip to content

Commit 2b99854

Browse files
authored
Use PermutationMatrix instead of indices (#475)
1 parent c0e3b06 commit 2b99854

File tree

7 files changed

+82
-64
lines changed

7 files changed

+82
-64
lines changed

include/albatross/src/cereal/eigen.hpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,29 @@ inline void load(Archive &archive,
9191
v.indices() = indices;
9292
}
9393

94+
template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
95+
typename _StorageIndex>
96+
inline void
97+
save(Archive &archive,
98+
const Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
99+
_StorageIndex> &v,
100+
const std::uint32_t) {
101+
archive(cereal::make_nvp("indices", v.indices()));
102+
}
103+
104+
template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
105+
typename _StorageIndex>
106+
inline void
107+
load(Archive &archive,
108+
Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
109+
_StorageIndex> &v,
110+
const std::uint32_t) {
111+
typename Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
112+
_StorageIndex>::IndicesType indices;
113+
archive(cereal::make_nvp("indices", indices));
114+
v.indices() = indices;
115+
}
116+
94117
template <typename Archive, typename _Scalar, int SizeAtCompileTime>
95118
inline void serialize(Archive &archive,
96119
Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime> &matrix,

include/albatross/src/cereal/gp.hpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ inline void serialize(Archive &archive, Fit<SparseGPFit<FeatureType>> &fit,
4141
archive(cereal::make_nvp("information", fit.information));
4242
archive(cereal::make_nvp("train_covariance", fit.train_covariance));
4343
archive(cereal::make_nvp("train_features", fit.train_features));
44-
archive(cereal::make_nvp("sigma_R", fit.sigma_R));
45-
archive(cereal::make_nvp("permutation_indices", fit.permutation_indices));
44+
archive(cereal::make_nvp("R", fit.R));
45+
archive(cereal::make_nvp("P", fit.P));
4646
if (version > 1) {
4747
archive(cereal::make_nvp("numerical_rank", fit.numerical_rank));
4848
} else {
@@ -53,19 +53,19 @@ inline void serialize(Archive &archive, Fit<SparseGPFit<FeatureType>> &fit,
5353

5454
template <typename Archive, typename CovFunc, typename MeanFunc,
5555
typename ImplType>
56-
void save(Archive &archive,
57-
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
58-
const std::uint32_t) {
56+
inline void save(Archive &archive,
57+
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
58+
const std::uint32_t) {
5959
archive(cereal::make_nvp("name", gp.get_name()));
6060
archive(cereal::make_nvp("params", gp.get_params()));
6161
archive(cereal::make_nvp("insights", gp.insights));
6262
}
6363

6464
template <typename Archive, typename CovFunc, typename MeanFunc,
6565
typename ImplType>
66-
void load(Archive &archive,
67-
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
68-
const std::uint32_t version) {
66+
inline void load(Archive &archive,
67+
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
68+
const std::uint32_t version) {
6969
if (version > 0) {
7070
std::string model_name;
7171
archive(cereal::make_nvp("name", model_name));

include/albatross/src/core/declarations.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ template <typename... Ts> class variant;
2121

2222
using mapbox::util::variant;
2323

24+
/*
25+
* Permutations
26+
*/
27+
namespace Eigen {
28+
using PermutationMatrixX = PermutationMatrix<Dynamic, Dynamic, Index>;
29+
}
30+
2431
namespace albatross {
2532

2633
/*

include/albatross/src/linalg/qr_utils.hpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,31 @@ get_R(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
2525
.template triangularView<Eigen::Upper>();
2626
}
2727

28+
inline Eigen::PermutationMatrixX
29+
get_P(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
30+
return Eigen::PermutationMatrixX(
31+
qr.colsPermutation().indices().template cast<Eigen::Index>());
32+
}
33+
2834
/*
2935
* Computes R^-T P^T rhs given R and P from a QR decomposition.
3036
*/
31-
template <typename MatrixType, typename PermutationIndicesType>
37+
template <typename MatrixType, typename PermutationScalar>
3238
inline Eigen::MatrixXd
3339
sqrt_solve(const Eigen::MatrixXd &R,
34-
const PermutationIndicesType &permutation_indices,
40+
const Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
41+
PermutationScalar> &P,
3542
const MatrixType &rhs) {
36-
37-
Eigen::MatrixXd sqrt(rhs.rows(), rhs.cols());
38-
for (Eigen::Index i = 0; i < permutation_indices.size(); ++i) {
39-
sqrt.row(i) = rhs.row(permutation_indices.coeff(i));
40-
}
41-
sqrt = R.template triangularView<Eigen::Upper>().transpose().solve(sqrt);
42-
return sqrt;
43+
return R.template triangularView<Eigen::Upper>().transpose().solve(
44+
P.transpose() * rhs);
4345
}
4446

4547
template <typename MatrixType>
4648
inline Eigen::MatrixXd
4749
sqrt_solve(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr,
4850
const MatrixType &rhs) {
4951
const Eigen::MatrixXd R = get_R(qr);
50-
return sqrt_solve(R, qr.colsPermutation().indices(), rhs);
52+
return sqrt_solve(R, qr.colsPermutation(), rhs);
5153
}
5254

5355
} // namespace albatross

include/albatross/src/linalg/spqr_utils.hpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,20 @@ using SparseMatrix = Eigen::SparseMatrix<double>;
1919

2020
using SPQR = Eigen::SPQR<SparseMatrix>;
2121

22-
using SparsePermutationMatrix =
23-
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
24-
SPQR::StorageIndex>;
25-
2622
inline Eigen::MatrixXd get_R(const SPQR &qr) {
2723
return qr.matrixR()
2824
.topLeftCorner(qr.cols(), qr.cols())
2925
.template triangularView<Eigen::Upper>();
3026
}
3127

28+
inline Eigen::PermutationMatrixX get_P(const SPQR &qr) {
29+
return Eigen::PermutationMatrixX(
30+
qr.colsPermutation().indices().template cast<Eigen::Index>());
31+
}
32+
3233
template <typename MatrixType>
3334
inline Eigen::MatrixXd sqrt_solve(const SPQR &qr, const MatrixType &rhs) {
34-
return sqrt_solve(get_R(qr), qr.colsPermutation().indices(), rhs);
35+
return sqrt_solve(get_R(qr), get_P(qr), rhs);
3536
}
3637

3738
// Matrices with any dimension smaller than this will use a special

include/albatross/src/models/sparse_gp.hpp

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -97,21 +97,19 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {
9797

9898
std::vector<FeatureType> train_features;
9999
Eigen::SerializableLDLT train_covariance;
100-
Eigen::MatrixXd sigma_R;
101-
PermutationIndices permutation_indices;
100+
Eigen::MatrixXd R;
101+
Eigen::PermutationMatrixX P;
102102
Eigen::VectorXd information;
103103
Eigen::Index numerical_rank;
104104

105105
Fit(){};
106106

107107
Fit(const std::vector<FeatureType> &features_,
108108
const Eigen::SerializableLDLT &train_covariance_,
109-
const Eigen::MatrixXd &sigma_R_,
110-
PermutationIndices &&permutation_indices_,
109+
const Eigen::MatrixXd &R_, const Eigen::PermutationMatrixX &P_,
111110
const Eigen::VectorXd &information_, Eigen::Index numerical_rank_)
112-
: train_features(features_), train_covariance(train_covariance_),
113-
sigma_R(sigma_R_), permutation_indices(std::move(permutation_indices_)),
114-
information(information_), numerical_rank(numerical_rank_) {}
111+
: train_features(features_), train_covariance(train_covariance_), R(R_),
112+
P(P_), information(information_), numerical_rank(numerical_rank_) {}
115113

116114
void shift_mean(const Eigen::VectorXd &mean_shift) {
117115
ALBATROSS_ASSERT(mean_shift.size() == information.size());
@@ -120,9 +118,8 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {
120118

121119
bool operator==(const Fit<SparseGPFit<FeatureType>> &other) const {
122120
return (train_features == other.train_features &&
123-
train_covariance == other.train_covariance &&
124-
sigma_R == other.sigma_R &&
125-
permutation_indices == other.permutation_indices &&
121+
train_covariance == other.train_covariance && R == other.R &&
122+
P.indices() == other.P.indices() &&
126123
information == other.information &&
127124
numerical_rank == other.numerical_rank);
128125
}
@@ -325,20 +322,17 @@ class SparseGaussianProcessRegression
325322
compute_internal_components(old_fit.train_features, features, targets,
326323
&A_ldlt, &K_uu_ldlt, &K_fu, &y);
327324

328-
const Eigen::Index n_old = old_fit.sigma_R.rows();
325+
const Eigen::Index n_old = old_fit.R.rows();
329326
const Eigen::Index n_new = A_ldlt.rows();
330-
const Eigen::Index k = old_fit.sigma_R.cols();
327+
const Eigen::Index k = old_fit.R.cols();
331328
Eigen::MatrixXd B = Eigen::MatrixXd::Zero(n_old + n_new, k);
332329

333330
ALBATROSS_ASSERT(n_old == k);
334331

335332
// Form:
336333
// B = |R_old P_old^T| = |Q_1| R P^T
337334
// |A^{-1/2} K_fu| |Q_2|
338-
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
339-
const Eigen::Index &pi = old_fit.permutation_indices.coeff(i);
340-
B.col(pi).topRows(i + 1) = old_fit.sigma_R.col(i).topRows(i + 1);
341-
}
335+
B.topRows(old_fit.P.rows()) = old_fit.R * old_fit.P.transpose();
342336
B.bottomRows(n_new) = A_ldlt.sqrt_solve(K_fu);
343337
const auto B_qr = QRImplementation::compute(B, Base::threads_.get());
344338

@@ -347,13 +341,9 @@ class SparseGaussianProcessRegression
347341
// |A^{-1/2} y |
348342
ALBATROSS_ASSERT(old_fit.information.size() == n_old);
349343
Eigen::VectorXd y_augmented(n_old + n_new);
350-
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
351-
y_augmented[i] =
352-
old_fit.information[old_fit.permutation_indices.coeff(i)];
353-
}
354344
y_augmented.topRows(n_old) =
355-
old_fit.sigma_R.template triangularView<Eigen::Upper>() *
356-
y_augmented.topRows(n_old);
345+
old_fit.R.template triangularView<Eigen::Upper>() *
346+
(old_fit.P.transpose() * old_fit.information);
357347

358348
y_augmented.bottomRows(n_new) = A_ldlt.sqrt_solve(y, Base::threads_.get());
359349
const Eigen::VectorXd v = B_qr->solve(y_augmented);
@@ -365,10 +355,9 @@ class SparseGaussianProcessRegression
365355
Eigen::VectorXd::Constant(B_qr->cols(), details::cSparseRNugget);
366356
}
367357
using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
368-
return FitType(
369-
old_fit.train_features, old_fit.train_covariance, R,
370-
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
371-
B_qr->rank());
358+
359+
return FitType(old_fit.train_features, old_fit.train_covariance, R,
360+
get_P(*B_qr), v, B_qr->rank());
372361
}
373362

374363
// Here we create the QR decomposition of:
@@ -415,10 +404,7 @@ class SparseGaussianProcessRegression
415404
using InducingPointFeatureType = typename std::decay<decltype(u[0])>::type;
416405

417406
using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
418-
return FitType(
419-
u, K_uu_ldlt, get_R(*B_qr),
420-
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
421-
B_qr->rank());
407+
return FitType(u, K_uu_ldlt, get_R(*B_qr), get_P(*B_qr), v, B_qr->rank());
422408
}
423409

424410
template <typename FeatureType>
@@ -471,9 +457,8 @@ class SparseGaussianProcessRegression
471457
const Eigen::MatrixXd sigma_inv_sqrt = C_ldlt.sqrt_solve(K_zz);
472458
const auto B_qr = QRImplementation::compute(sigma_inv_sqrt, nullptr);
473459

474-
new_fit.permutation_indices =
475-
B_qr->colsPermutation().indices().template cast<Eigen::Index>();
476-
new_fit.sigma_R = get_R(*B_qr);
460+
new_fit.P = get_P(*B_qr);
461+
new_fit.R = get_R(*B_qr);
477462
new_fit.numerical_rank = B_qr->rank();
478463

479464
return output;
@@ -519,8 +504,8 @@ class SparseGaussianProcessRegression
519504
Q_sqrt.cwiseProduct(Q_sqrt).array().colwise().sum();
520505
marginal_variance -= Q_diag;
521506

522-
const Eigen::MatrixXd S_sqrt = sqrt_solve(
523-
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
507+
const Eigen::MatrixXd S_sqrt =
508+
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);
524509
const Eigen::VectorXd S_diag =
525510
S_sqrt.cwiseProduct(S_sqrt).array().colwise().sum();
526511
marginal_variance += S_diag;
@@ -537,8 +522,8 @@ class SparseGaussianProcessRegression
537522
this->covariance_function_(sparse_gp_fit.train_features, features);
538523
const Eigen::MatrixXd prior_cov = this->covariance_function_(features);
539524

540-
const Eigen::MatrixXd S_sqrt = sqrt_solve(
541-
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
525+
const Eigen::MatrixXd S_sqrt =
526+
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);
542527

543528
const Eigen::MatrixXd Q_sqrt =
544529
sparse_gp_fit.train_covariance.sqrt_solve(cross_cov);

tests/test_sparse_gp.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,10 @@ TYPED_TEST(SparseGaussianProcessTest, test_update) {
322322
(updated_in_place_pred.covariance - full_pred.covariance).norm();
323323

324324
auto compute_sigma = [](const auto &fit_model) -> Eigen::MatrixXd {
325-
const Eigen::Index n = fit_model.get_fit().sigma_R.cols();
326-
Eigen::MatrixXd sigma = sqrt_solve(fit_model.get_fit().sigma_R,
327-
fit_model.get_fit().permutation_indices,
328-
Eigen::MatrixXd::Identity(n, n));
325+
const Eigen::Index n = fit_model.get_fit().R.cols();
326+
Eigen::MatrixXd sigma =
327+
sqrt_solve(fit_model.get_fit().R, fit_model.get_fit().P,
328+
Eigen::MatrixXd::Identity(n, n));
329329
return sigma.transpose() * sigma;
330330
};
331331

0 commit comments

Comments
 (0)