@@ -102,14 +102,14 @@ def approx_hsgp_hyperparams(
102102 based on recommendations from Ruitort-Mayol et. al.
103103
104104 In practice, you need to choose `c` large enough to handle the largest lengthscales,
105- and `m` large enough to accommodate the smallest lengthscales. Use your prior on the
106- lengthscale as guidance for setting the prior range. For example, if you believe
105+ and `m` large enough to accommodate the smallest lengthscales. Use your prior on the
106+ lengthscale as guidance for setting the prior range. For example, if you believe
107107 that 95% of the prior mass of the lengthscale is between 1 and 5, set the
108108 `lengthscale_range` to be [1, 5], or maybe a touch wider.
109109
110- Also, be sure to pass in an `x` that is exemplary of the domain not just of your
110+ Also, be sure to pass in an `x` that is exemplary of the domain not just of your
111111 training data, but also where you intend to make predictions. For instance, if your
112- training x values are from [0, 10], and you intend to predict from [7, 15], you can
112+ training x values are from [0, 10], and you intend to predict from [7, 15], you can
113113 pass in `x_range = [0, 15]`.
114114
115115 NB: These recommendations are based on a one-dimensional GP.
@@ -295,6 +295,7 @@ def __init__(
295295
296296 if parametrization is not None :
297297 parametrization = parametrization .lower ().replace ("-" , "" ).replace ("_" , "" )
298+
298299 if parametrization not in ["centered" , "noncentered" ]:
299300 raise ValueError ("`parametrization` must be either 'centered' or 'noncentered'." )
300301
@@ -597,6 +598,7 @@ def __init__(
597598
598599 self ._m = m
599600 self .scale = scale
601+ self ._X_center = None
600602
601603 super ().__init__ (mean_func = mean_func , cov_func = cov_func )
602604
@@ -672,8 +674,8 @@ def prior_linearized(self, X: TensorLike):
672674 # Important: fix the computation of the midpoint of X.
673675 # If X is mutated later, the training midpoint will be subtracted, not the testing one.
674676 if self ._X_center is None :
675- self ._X_center = (pt .max (Xs , axis = 0 ) + pt .min (Xs , axis = 0 )).eval () / 2
676- Xs = Xs - self ._X_center # center for accurate computation
677+ self ._X_center = (pt .max (X , axis = 0 ) + pt .min (X , axis = 0 )).eval () / 2
678+ Xs = X - self ._X_center # center for accurate computation
677679
678680 # Index Xs using input_dim and active_dims of covariance function
679681 Xs , _ = self .cov_func ._slice (Xs )
@@ -715,7 +717,7 @@ def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ign
715717
716718 def _build_conditional (self , Xnew ):
717719 try :
718- beta , X_mean = self ._beta , self ._X_mean
720+ beta , X_center = self ._beta , self ._X_center
719721
720722 except AttributeError :
721723 raise ValueError (
@@ -724,7 +726,9 @@ def _build_conditional(self, Xnew):
724726
725727 Xnew , _ = self .cov_func ._slice (Xnew )
726728
727- phi_cos , phi_sin = calc_basis_periodic (Xnew - X_mean , self .cov_func .period , self ._m , tl = pt )
729+ phi_cos , phi_sin = calc_basis_periodic (
730+ Xnew - X_center , self .cov_func .period , self ._m , tl = pt
731+ )
728732 m = self ._m
729733 J = pt .arange (0 , m , 1 )
730734 # rescale basis coefficients by the sqrt variance term
0 commit comments