@@ -330,7 +330,7 @@ def L(self) -> pt.TensorVariable:
330330 def L (self , value : TensorLike ):
331331 self ._L = pt .as_tensor_variable (value )
332332
333- def prior_linearized (self , Xs : TensorLike ):
333+ def prior_linearized (self , X : TensorLike ):
334334 """Linearized version of the HSGP. Returns the Laplace eigenfunctions and the square root
335335 of the power spectral density needed to create the GP.
336336
@@ -343,7 +343,7 @@ def prior_linearized(self, Xs: TensorLike):
343343
344344 Parameters
345345 ----------
346- Xs : array-like
346+ X : array-like
347347 Function input values.
348348
349349 Returns
@@ -371,9 +371,9 @@ def prior_linearized(self, Xs: TensorLike):
371371 # L = [10] means the approximation is valid from Xs = [-10, 10]
372372 gp = pm.gp.HSGP(m=[200], L=[10], cov_func=cov_func)
373373
374+ # Set X as Data so it can be mutated later, and then pass it to the GP
374375 X = pm.Data("X", X)
375- # Pass X to the GP
376- phi, sqrt_psd = gp.prior_linearized(Xs=X)
376+ phi, sqrt_psd = gp.prior_linearized(X=X)
377377
378378 # Specify standard normal prior in the coefficients, the number of which
379379 # is given by the number of basis vectors, saved in `n_basis_vectors`.
@@ -403,8 +403,8 @@ def prior_linearized(self, Xs: TensorLike):
403403 # Important: fix the computation of the midpoint of X.
404404 # If X is mutated later, the training midpoint will be subtracted, not the testing one.
405405 if self ._X_center is None :
406- self ._X_center = (pt .max (Xs , axis = 0 ) + pt .min (Xs , axis = 0 )).eval () / 2
407- Xs = Xs - self ._X_center # center for accurate computation
406+ self ._X_center = (pt .max (X , axis = 0 ) + pt .min (X , axis = 0 )).eval () / 2
407+ Xs = X - self ._X_center # center for accurate computation
408408
409409 # Index Xs using input_dim and active_dims of covariance function
410410 Xs , _ = self .cov_func ._slice (Xs )
@@ -600,7 +600,7 @@ def __init__(
600600
601601 super ().__init__ (mean_func = mean_func , cov_func = cov_func )
602602
603- def prior_linearized (self , Xs : TensorLike ):
603+ def prior_linearized (self , X : TensorLike ):
604604 """Linearized version of the approximation. Returns the cosine and sine bases and coefficients
605605 of the expansion needed to create the GP.
606606
@@ -615,8 +615,8 @@ def prior_linearized(self, Xs: TensorLike):
615615
616616 Parameters
617617 ----------
618- Xs : array-like
619- Function input values. Assumes they have been mean subtracted or centered at zero.
618+ X : array-like
619+ Function input values.
620620
621621 Returns
622622 -------
@@ -640,15 +640,9 @@ def prior_linearized(self, Xs: TensorLike):
640640 # m=200 means 200 basis vectors
641641 gp = pm.gp.HSGPPeriodic(m=200, scale=scale, cov_func=cov_func)
642642
643- # Order is important. First calculate the mean, then make X a shared variable,
644- # then subtract the mean. When X is mutated later, the correct mean will be
645- # subtracted.
646- X_mean = np.mean(X, axis=0)
647- X = pm.MutableData("X", X)
648- Xs = X - X_mean
649-
650- # Pass the zero-subtracted Xs in to the GP
651- (phi_cos, phi_sin), psd = gp.prior_linearized(Xs=Xs)
643+ # Set X as Data so it can be mutated later, and then pass it to the GP
644+ X = pm.Data("X", X)
645+ (phi_cos, phi_sin), psd = gp.prior_linearized(X=X)
652646
653647 # Specify standard normal prior in the coefficients. The number of which
654648 # is twice the number of basis vectors minus one.
@@ -675,6 +669,13 @@ def prior_linearized(self, Xs: TensorLike):
675669 with model:
676670 ppc = pm.sample_posterior_predictive(idata, var_names=["f"])
677671 """
672+ # Important: fix the computation of the midpoint of X.
673+ # If X is mutated later, the training midpoint will be subtracted, not the testing one.
674+ if self ._X_center is None :
675+ self ._X_center = (pt .max (Xs , axis = 0 ) + pt .min (Xs , axis = 0 )).eval () / 2
676+ Xs = Xs - self ._X_center # center for accurate computation
677+
678+ # Index Xs using input_dim and active_dims of covariance function
678679 Xs , _ = self .cov_func ._slice (Xs )
679680
680681 phi_cos , phi_sin = calc_basis_periodic (Xs , self .cov_func .period , self ._m , tl = pt )
@@ -697,9 +698,7 @@ def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ign
697698 dims: None
698699 Dimension name for the GP random variable.
699700 """
700- self ._X_mean = pt .mean (X , axis = 0 )
701-
702- (phi_cos , phi_sin ), psd = self .prior_linearized (X - self ._X_mean )
701+ (phi_cos , phi_sin ), psd = self .prior_linearized (X )
703702
704703 m = self ._m
705704 self ._beta = pm .Normal (f"{ name } _hsgp_coeffs_" , size = (m * 2 - 1 ))
0 commit comments