@@ -703,15 +703,86 @@ class SparseGaussianProcessRegression
703703
704704// rebase_inducing_points takes a Sparse GP which was fit using some set of
705705// inducing points and creates a new fit relative to new inducing points.
706+ //
706707// Note that this will NOT be the equivalent to having fit the model with
707708// the new inducing points since some information may have been lost in
708709// the process.
709- template <typename ModelType, typename FeatureType, typename NewFeatureType>
710+ //
711+ // For example, consider the extreme case where your first fit
712+ // doesn't have any inducing points at all, all the information from the first
713+ // observations will have been lost, and when you rebase on new inducing points
714+ // you'd have the prior for those new points.
715+ //
716+ // For implementation details see the online documentation.
717+ //
718+ // The summary involves:
719+ // - Compute K_nn = cov(new, new)
720+ // - Compute K_pn = cov(prev, new)
721+ // - Compute A = L_pp^-1 K_pn
722+ // - Solve for Lhat_nn = chol(K_nn - A^T A)
723+ // - Solve for QRP^T = [Lat_nn
724+ // R_p P_p^T L_pp^-T A]
725+ // - Solve for L_nn = chol(K_nn)
726+ // - Solve for v_n = K_nn^-1 K_np v_p
727+ //
728+ template <typename CovFunc, typename MeanFunc, typename GrouperFunction,
729+ typename InducingPointStrategy, typename QRImplementation,
730+ typename FeatureType, typename NewFeatureType>
710731auto rebase_inducing_points (
711- const FitModel<ModelType, Fit<SparseGPFit<FeatureType>>> &fit_model,
732+ const FitModel<SparseGaussianProcessRegression<
733+ CovFunc, MeanFunc, GrouperFunction,
734+ InducingPointStrategy, QRImplementation>,
735+ Fit<SparseGPFit<FeatureType>>> &fit_model,
712736 const std::vector<NewFeatureType> &new_inducing_points) {
713- return fit_model.get_model ().fit_from_prediction (
714- new_inducing_points, fit_model.predict (new_inducing_points).joint ());
737+
738+ const auto &cov = fit_model.get_model ().get_covariance ();
739+ // Compute K_nn = cov(new, new)
740+ const Eigen::MatrixXd K_nn =
741+ cov (new_inducing_points, fit_model.get_model ().threads_ .get ());
742+
743+ // Compute K_pn = cov(prev, new)
744+ const Fit<SparseGPFit<FeatureType>> &prev_fit = fit_model.get_fit ();
745+ const auto &prev_inducing_points = prev_fit.train_features ;
746+ const Eigen::MatrixXd K_pn = cov (prev_inducing_points, new_inducing_points,
747+ fit_model.get_model ().threads_ .get ());
748+ // A = L_pp^-1 K_pn
749+ const Eigen::MatrixXd A = prev_fit.train_covariance .sqrt_solve (K_pn);
750+ const Eigen::Index p = K_pn.rows ();
751+ const Eigen::Index n = K_nn.rows ();
752+ Eigen::MatrixXd B = Eigen::MatrixXd::Zero (n + p, n);
753+
754+ // B[upper] = R P^T L_pp^-T A
755+ const auto LTiA = prev_fit.train_covariance .sqrt_transpose_solve (A);
756+ B.topRows (p) = prev_fit.R .template triangularView <Eigen::Upper>() *
757+ (prev_fit.P .transpose () * LTiA);
758+
759+ // B[lower] = chol(K_nn - A^T A)^T
760+ Eigen::MatrixXd S_nn = K_nn - A.transpose () * A;
761+ // This cholesky operation here is the most likely to experience numerical
762+ // instability because of the A^T A subtraction involved, so we add a nugget.
763+ const double nugget =
764+ fit_model.get_model ().get_params ()[details::inducing_nugget_name ()].value ;
765+ assert (nugget >= 0 );
766+ S_nn.diagonal () += Eigen::VectorXd::Constant (S_nn.rows (), nugget);
767+ B.bottomRows (n) = Eigen::SerializableLDLT (S_nn).sqrt_transpose ();
768+
769+ const auto B_qr =
770+ QRImplementation::compute (B, fit_model.get_model ().threads_ .get ());
771+
772+ Fit<SparseGPFit<FeatureType>> new_fit;
773+ new_fit.train_features = new_inducing_points;
774+ new_fit.train_covariance = Eigen::SerializableLDLT (K_nn);
775+ // v_n = K_nn^-1 K_np v_p
776+ new_fit.information = new_fit.train_covariance .solve (
777+ fit_model.predict (new_inducing_points).mean ());
778+ new_fit.P = get_P (*B_qr);
779+ new_fit.R = get_R (*B_qr);
780+ new_fit.numerical_rank = B_qr->rank ();
781+
782+ return FitModel<
783+ SparseGaussianProcessRegression<CovFunc, MeanFunc, GrouperFunction,
784+ InducingPointStrategy, QRImplementation>,
785+ Fit<SparseGPFit<FeatureType>>>(fit_model.get_model (), std::move (new_fit));
715786}
716787
717788template <typename CovFunc, typename MeanFunc, typename GrouperFunction,
0 commit comments