3737__all__ = ["Latent" , "Marginal" , "TP" , "MarginalApprox" , "LatentKron" , "MarginalKron" ]
3838
3939
40+ _noise_deprecation_warning = (
41+ "The 'noise' parameter has been been changed to 'sigma' "
42+ "in order to standardize the GP API and will be "
43+ "deprecated in future releases."
44+ )
45+
46+
47+ def _handle_sigma_noise_parameters (sigma , noise ):
48+ """Helper function for transition of 'noise' parameter to be named 'sigma'."""
49+
50+ if (sigma is None and noise is None ) or (sigma is not None and noise is not None ):
51+ raise ValueError ("'sigma' argument must be specified." )
52+
53+ if sigma is None :
54+ warnings .warn (_noise_deprecation_warning , FutureWarning )
55+ return noise
56+
57+ return sigma
58+
59+
4060class Base :
4161 R"""
4262 Base class.
@@ -218,7 +238,7 @@ def conditional(self, name, Xnew, given=None, jitter=JITTER_DEFAULT, **kwargs):
218238 Xnew: array-like
219239 Function input values.
220240 given: dict
221- Can optionally take as key value pairs: `X`, `y`, `noise`,
241+ Can optionally take as key value pairs: `X`, `y`,
222242 and `gp`. See the section in the documentation on additive GP
223243 models in PyMC for more information.
224244 jitter: scalar
@@ -359,7 +379,7 @@ def conditional(self, name, Xnew, jitter=JITTER_DEFAULT, **kwargs):
359379 return pm .MvStudentT (name , nu = nu2 , mu = mu , cov = cov , ** kwargs )
360380
361381
362- @conditioned_vars (["X" , "y" , "noise " ])
382+ @conditioned_vars (["X" , "y" , "sigma " ])
363383class Marginal (Base ):
364384 R"""
365385 Marginal Gaussian process.
@@ -393,7 +413,7 @@ class Marginal(Base):
393413
394414 # Place a GP prior over the function f.
395415 sigma = pm.HalfCauchy("sigma", beta=3)
396- y_ = gp.marginal_likelihood("y", X=X, y=y, noise =sigma)
416+ y_ = gp.marginal_likelihood("y", X=X, y=y, sigma =sigma)
397417
398418 ...
399419
@@ -405,15 +425,15 @@ class Marginal(Base):
405425 fcond = gp.conditional("fcond", Xnew=Xnew)
406426 """
407427
408- def _build_marginal_likelihood (self , X , noise , jitter ):
428+ def _build_marginal_likelihood (self , X , noise_func , jitter ):
409429 mu = self .mean_func (X )
410430 Kxx = self .cov_func (X )
411- Knx = noise (X )
431+ Knx = noise_func (X )
412432 cov = Kxx + Knx
413433 return mu , stabilize (cov , jitter )
414434
415435 def marginal_likelihood (
416- self , name , X , y , noise , jitter = JITTER_DEFAULT , is_observed = True , ** kwargs
436+ self , name , X , y , sigma = None , noise = None , jitter = JITTER_DEFAULT , is_observed = True , ** kwargs
417437 ):
418438 R"""
419439 Returns the marginal likelihood distribution, given the input
@@ -435,23 +455,25 @@ def marginal_likelihood(
435455 y: array-like
436456 Data that is the sum of the function with the GP prior and Gaussian
437457 noise. Must have shape `(n, )`.
438- noise : scalar, Variable, or Covariance
458+ sigma : scalar, Variable, or Covariance
439459 Standard deviation of the Gaussian noise. Can also be a Covariance for
440460 non-white noise.
461+ noise: scalar, Variable, or Covariance
462+ Previous parameterization of `sigma`.
441463 jitter: scalar
442464 A small correction added to the diagonal of positive semi-definite
443465 covariance matrices to ensure numerical stability.
444466 **kwargs
445467 Extra keyword arguments that are passed to `MvNormal` distribution
446468 constructor.
447469 """
470+ sigma = _handle_sigma_noise_parameters (sigma = sigma , noise = noise )
448471
449- if not isinstance (noise , Covariance ):
450- noise = pm .gp .cov .WhiteNoise (noise )
451- mu , cov = self ._build_marginal_likelihood (X , noise , jitter )
472+ noise_func = sigma if isinstance (sigma , Covariance ) else pm .gp .cov .WhiteNoise (sigma )
473+ mu , cov = self ._build_marginal_likelihood (X = X , noise_func = noise_func , jitter = jitter )
452474 self .X = X
453475 self .y = y
454- self .noise = noise
476+ self .sigma = noise_func
455477 if is_observed :
456478 return pm .MvNormal (name , mu = mu , cov = cov , observed = y , ** kwargs )
457479 else :
@@ -472,20 +494,24 @@ def _get_given_vals(self, given):
472494 else :
473495 cov_total = self .cov_func
474496 mean_total = self .mean_func
475- if all (val in given for val in ["X" , "y" , "noise" ]):
476- X , y , noise = given ["X" ], given ["y" ], given ["noise" ]
477- if not isinstance (noise , Covariance ):
478- noise = pm .gp .cov .WhiteNoise (noise )
497+
498+ if "noise" in given :
499+ warnings .warn (_noise_deprecation_warning , FutureWarning )
500+ given ["sigma" ] = given ["noise" ]
501+
502+ if all (val in given for val in ["X" , "y" , "sigma" ]):
503+ X , y , sigma = given ["X" ], given ["y" ], given ["sigma" ]
504+ noise_func = sigma if isinstance (sigma , Covariance ) else pm .gp .cov .WhiteNoise (sigma )
479505 else :
480- X , y , noise = self .X , self .y , self .noise
481- return X , y , noise , cov_total , mean_total
506+ X , y , noise_func = self .X , self .y , self .sigma
507+ return X , y , noise_func , cov_total , mean_total
482508
483509 def _build_conditional (
484- self , Xnew , pred_noise , diag , X , y , noise , cov_total , mean_total , jitter
510+ self , Xnew , pred_noise , diag , X , y , noise_func , cov_total , mean_total , jitter
485511 ):
486512 Kxx = cov_total (X )
487513 Kxs = self .cov_func (X , Xnew )
488- Knx = noise (X )
514+ Knx = noise_func (X )
489515 rxx = y - mean_total (X )
490516 L = cholesky (stabilize (Kxx , jitter ) + Knx )
491517 A = solve_lower (L , Kxs )
@@ -495,13 +521,13 @@ def _build_conditional(
495521 Kss = self .cov_func (Xnew , diag = True )
496522 var = Kss - at .sum (at .square (A ), 0 )
497523 if pred_noise :
498- var += noise (Xnew , diag = True )
524+ var += noise_func (Xnew , diag = True )
499525 return mu , var
500526 else :
501527 Kss = self .cov_func (Xnew )
502528 cov = Kss - at .dot (at .transpose (A ), A )
503529 if pred_noise :
504- cov += noise (Xnew )
530+ cov += noise_func (Xnew )
505531 return mu , cov if pred_noise else stabilize (cov , jitter )
506532
507533 def conditional (
@@ -531,7 +557,7 @@ def conditional(
531557 Whether or not observation noise is included in the conditional.
532558 Default is `False`.
533559 given: dict
534- Can optionally take as key value pairs: `X`, `y`, `noise `,
560+ Can optionally take as key value pairs: `X`, `y`, `sigma `,
535561 and `gp`. See the section in the documentation on additive GP
536562 models in PyMC for more information.
537563 jitter: scalar
@@ -720,7 +746,9 @@ def _build_marginal_likelihood_loglik(self, y, X, Xu, sigma, jitter):
720746 quadratic = 0.5 * (at .dot (r , r_l ) - at .dot (c , c ))
721747 return - 1.0 * (constant + logdet + quadratic + trace )
722748
723- def marginal_likelihood (self , name , X , Xu , y , noise = None , jitter = JITTER_DEFAULT , ** kwargs ):
749+ def marginal_likelihood (
750+ self , name , X , Xu , y , sigma = None , noise = None , jitter = JITTER_DEFAULT , ** kwargs
751+ ):
724752 R"""
725753 Returns the approximate marginal likelihood distribution, given the input
726754 locations `X`, inducing point locations `Xu`, data `y`, and white noise
@@ -738,8 +766,10 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT,
738766 y: array-like
739767 Data that is the sum of the function with the GP prior and Gaussian
740768 noise. Must have shape `(n, )`.
741- noise : scalar, Variable
769+ sigma : scalar, Variable
742770 Standard deviation of the Gaussian noise.
771+ noise: scalar, Variable
772+ Previous parameterization of `sigma`
743773 jitter: scalar
744774 A small correction added to the diagonal of positive semi-definite
745775 covariance matrices to ensure numerical stability.
@@ -752,12 +782,11 @@ def marginal_likelihood(self, name, X, Xu, y, noise=None, jitter=JITTER_DEFAULT,
752782 self .Xu = Xu
753783 self .y = y
754784
755- if noise is None :
756- raise ValueError ("noise argument must be specified" )
757- else :
758- self .sigma = noise
785+ self .sigma = _handle_sigma_noise_parameters (sigma = sigma , noise = noise )
759786
760- approx_loglik = self ._build_marginal_likelihood_loglik (y , X , Xu , noise , jitter )
787+ approx_loglik = self ._build_marginal_likelihood_loglik (
788+ y = self .y , X = self .X , Xu = self .Xu , sigma = self .sigma , jitter = jitter
789+ )
761790 pm .Potential (f"marginalapprox_loglik_{ name } " , approx_loglik , ** kwargs )
762791
763792 def _build_conditional (
@@ -828,7 +857,7 @@ def conditional(
828857 Whether or not observation noise is included in the conditional.
829858 Default is `False`.
830859 given: dict
831- Can optionally take as key value pairs: `X`, `Xu`, `y`, `noise `,
860+ Can optionally take as key value pairs: `X`, `Xu`, `y`, `sigma `,
832861 and `gp`. See the section in the documentation on additive GP
833862 models in PyMC for more information.
834863 jitter: scalar
0 commit comments