Skip to content

Commit 9617994

Browse files
committed
New Rebase Approach
1 parent 8e2f132 commit 9617994

File tree

1 file changed

+75
-4
lines changed

1 file changed

+75
-4
lines changed

include/albatross/src/models/sparse_gp.hpp

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -703,15 +703,86 @@ class SparseGaussianProcessRegression
703703

704704
// rebase_inducing_points takes a Sparse GP which was fit using some set of
// inducing points and creates a new fit relative to new inducing points.
//
// Note that this will NOT be the equivalent to having fit the model with
// the new inducing points since some information may have been lost in
// the process.
//
// For example, consider the extreme case where your first fit
// doesn't have any inducing points at all, all the information from the first
// observations will have been lost, and when you rebase on new inducing points
// you'd have the prior for those new points.
//
// For implementation details see the online documentation.
//
// The summary involves:
//   - Compute K_nn = cov(new, new)
//   - Compute K_pn = cov(prev, new)
//   - Compute A = L_pp^-1 K_pn
//   - Solve for Lhat_nn = chol(K_nn - A^T A)
//   - Solve for QRP^T = [Lhat_nn
//                        R_p P_p^T L_pp^-T A]
//     (the code stacks the R_p P_p^T L_pp^-T A block on top and the
//     Lhat_nn^T block on the bottom; row order does not change B^T B,
//     so the resulting R factor is equivalent)
//   - Factorize K_nn (via LDLT, used in place of chol(K_nn))
//   - Solve for v_n = K_nn^-1 K_np v_p
//
template <typename CovFunc, typename MeanFunc, typename GrouperFunction,
          typename InducingPointStrategy, typename QRImplementation,
          typename FeatureType, typename NewFeatureType>
auto rebase_inducing_points(
    const FitModel<SparseGaussianProcessRegression<
                       CovFunc, MeanFunc, GrouperFunction,
                       InducingPointStrategy, QRImplementation>,
                   Fit<SparseGPFit<FeatureType>>> &fit_model,
    const std::vector<NewFeatureType> &new_inducing_points) {

  const auto &cov = fit_model.get_model().get_covariance();
  // Compute K_nn = cov(new, new).
  const Eigen::MatrixXd K_nn =
      cov(new_inducing_points, fit_model.get_model().threads_.get());

  // Compute K_pn = cov(prev, new), the cross covariance between the
  // previous inducing points and the new ones.
  const Fit<SparseGPFit<FeatureType>> &prev_fit = fit_model.get_fit();
  const auto &prev_inducing_points = prev_fit.train_features;
  const Eigen::MatrixXd K_pn = cov(prev_inducing_points, new_inducing_points,
                                   fit_model.get_model().threads_.get());
  // A = L_pp^-1 K_pn, where L_pp is the (square root of the) previous
  // inducing-point covariance factorization.
  const Eigen::MatrixXd A = prev_fit.train_covariance.sqrt_solve(K_pn);
  const Eigen::Index p = K_pn.rows();
  const Eigen::Index n = K_nn.rows();
  // B is the (n + p) x n stacked matrix whose QR decomposition yields the
  // new fit's R and P factors.
  Eigen::MatrixXd B = Eigen::MatrixXd::Zero(n + p, n);

  // B[upper] = R P^T L_pp^-T A  (carries over information from the prev fit)
  const auto LTiA = prev_fit.train_covariance.sqrt_transpose_solve(A);
  B.topRows(p) = prev_fit.R.template triangularView<Eigen::Upper>() *
                 (prev_fit.P.transpose() * LTiA);

  // B[lower] = chol(K_nn - A^T A)^T, i.e. Lhat_nn^T from the summary above.
  Eigen::MatrixXd S_nn = K_nn - A.transpose() * A;
  // This cholesky operation here is the most likely to experience numerical
  // instability because of the A^T A subtraction involved, so we add a nugget
  // (a small positive value on the diagonal) to keep S_nn positive definite.
  const double nugget =
      fit_model.get_model().get_params()[details::inducing_nugget_name()].value;
  assert(nugget >= 0);
  S_nn.diagonal() += Eigen::VectorXd::Constant(S_nn.rows(), nugget);
  B.bottomRows(n) = Eigen::SerializableLDLT(S_nn).sqrt_transpose();

  // QR-decompose the stacked matrix; the implementation (serial/parallel)
  // is selected by the model's QRImplementation template parameter.
  const auto B_qr =
      QRImplementation::compute(B, fit_model.get_model().threads_.get());

  // Assemble the rebased fit relative to the new inducing points.
  Fit<SparseGPFit<FeatureType>> new_fit;
  new_fit.train_features = new_inducing_points;
  new_fit.train_covariance = Eigen::SerializableLDLT(K_nn);
  // v_n = K_nn^-1 * (posterior mean at the new inducing points), which
  // equals K_nn^-1 K_np v_p — the information vector relative to K_nn.
  new_fit.information = new_fit.train_covariance.solve(
      fit_model.predict(new_inducing_points).mean());
  new_fit.P = get_P(*B_qr);
  new_fit.R = get_R(*B_qr);
  new_fit.numerical_rank = B_qr->rank();

  // Return a FitModel that shares the original model (hyperparameters,
  // covariance, strategies) but holds the rebased fit.
  return FitModel<
      SparseGaussianProcessRegression<CovFunc, MeanFunc, GrouperFunction,
                                      InducingPointStrategy, QRImplementation>,
      Fit<SparseGPFit<FeatureType>>>(fit_model.get_model(), std::move(new_fit));
}
716787

717788
template <typename CovFunc, typename MeanFunc, typename GrouperFunction,

0 commit comments

Comments
 (0)