1717
1818from collections .abc import Sequence
1919from types import ModuleType
20- from typing import NamedTuple
2120
2221import numpy as np
2322import pytensor .tensor as pt
@@ -89,15 +88,9 @@ def calc_basis_periodic(
8988 return phi_cos , phi_sin
9089
9190
92- class HSGPParams (NamedTuple ):
93- m : int
94- c : float
95- S : float
96-
97-
9891def approx_hsgp_hyperparams (
9992 x_range : list [float ], lengthscale_range : list [float ], cov_func : str
100- ) -> HSGPParams :
93+ ) -> tuple [ int , float ] :
10194 """Utility function that uses heuristics to recommend minimum `m` and `c` values,
10295 based on recommendations from Ruitort-Mayol et. al.
10396
@@ -107,10 +100,10 @@ def approx_hsgp_hyperparams(
107100 that 95% of the prior mass of the lengthscale is between 1 and 5, set the
108101 `lengthscale_range` to be [1, 5], or maybe a touch wider.
109102
110- Also, be sure to pass in an `x ` that is exemplary of the domain not just of your
103+ Also, be sure to pass in an `x_range ` that is exemplary of the domain not just of your
111104 training data, but also where you intend to make predictions. For instance, if your
112- training x values are from [0, 10], and you intend to predict from [7, 15], you can
113- pass in `x_range = [0, 15]`.
105+ training x values are from [0, 10], and you intend to predict from [7, 15], the narrowest
106+ `x_range` you should pass in would be `x_range = [0, 15]`.
114107
115108 NB: These recommendations are based on a one-dimensional GP.
116109
@@ -126,15 +119,11 @@ def approx_hsgp_hyperparams(
126119
127120 Returns
128121 -------
129- HSGPParams
130- A named tuple containing the recommended values for `m`, `c`, and `S`.
131- - `m` : int
132- Number of basis vectors. Increasing it helps approximate smaller lengthscales, but increases computational cost.
133- - `c` : float
134- Scaling factor such that L = c * S, where L is the boundary of the approximation.
135- Increasing it helps approximate larger lengthscales, but may require increasing m.
136- - `S` : float
137- The value of `S`, which is half the range, or radius, of `x`.
122+ - `m` : int
123+ Number of basis vectors. Increasing it helps approximate smaller lengthscales, but increases computational cost.
124+ - `c` : float
125+ Scaling factor such that L = c * S, where L is the boundary of the approximation.
126+ Increasing it helps approximate larger lengthscales, but may require increasing m.
138127
139128 Raises
140129 ------
@@ -171,7 +160,7 @@ def approx_hsgp_hyperparams(
171160 c = max (a1 * (lengthscale_range [1 ] / S ), 1.2 )
172161 m = int (a2 * c / (lengthscale_range [0 ] / S ))
173162
174- return HSGPParams ( m = m , c = c , S = S )
163+ return m , c
175164
176165
177166class HSGP (Base ):
0 commit comments