1313# limitations under the License.
1414
1515
16- from typing import Sequence , Union
16+ from collections import namedtuple
17+ from typing import Sequence , Tuple , Union
1718
1819import numpy as np
1920import pymc as pm
2223__all__ = ["R2D2M2CP" ]
2324
2425
25- def _psivar2musigma (psi : pt .TensorVariable , explained_var : pt .TensorVariable , psi_mask ):
26+ def _psivar2musigma (
27+ psi : pt .TensorVariable ,
28+ explained_var : pt .TensorVariable ,
29+ psi_mask : Union [pt .TensorLike , None ],
30+ ) -> Tuple [pt .TensorVariable , pt .TensorVariable ]:
31+ sign = pt .sign (psi - 0.5 )
32+ if psi_mask is not None :
33+ # any computation might be ignored for ~psi_mask
34+ # sign and explained_var are used
35+ psi = pt .where (psi_mask , psi , 0.5 )
2636 pi = pt .erfinv (2 * psi - 1 )
2737 f = (1 / (2 * pi ** 2 + 1 )) ** 0.5
2838 sigma = explained_var ** 0.5 * f
2939 mu = sigma * pi * 2 ** 0.5
3040 if psi_mask is not None :
3141 return (
32- pt .where (psi_mask , mu , pt . sign ( pi ) * explained_var ** 0.5 ),
42+ pt .where (psi_mask , mu , sign * explained_var ** 0.5 ),
3343 pt .where (psi_mask , sigma , 0 ),
3444 )
3545 else :
@@ -47,7 +57,7 @@ def _R2D2M2CP_beta(
4757 psi_mask ,
4858 dims : Union [str , Sequence [str ]],
4959 centered = False ,
50- ):
60+ ) -> pt . TensorVariable :
5161 """R2D2M2CP beta prior.
5262
5363 Parameters
@@ -65,7 +75,7 @@ def _R2D2M2CP_beta(
6575 psi: tensor
6676 probability of a coefficients to be positive
6777 """
68- explained_variance = phi * pt .expand_dims (r2 * output_sigma ** 2 , - 1 )
78+ explained_variance = phi * pt .expand_dims (r2 * output_sigma ** 2 , ( - 1 ,) )
6979 mu_param , std_param = _psivar2musigma (psi , explained_variance , psi_mask = psi_mask )
7080 if not centered :
7181 with pm .Model (name ):
@@ -107,7 +117,10 @@ def _R2D2M2CP_beta(
107117 return beta
108118
109119
110- def _broadcast_as_dims (* values , dims ):
120+ def _broadcast_as_dims (
121+ * values : np .ndarray ,
122+ dims : Sequence [str ],
123+ ) -> Union [Tuple [np .ndarray , ...], np .ndarray ]:
111124 model = pm .modelcontext (None )
112125 shape = [len (model .coords [d ]) for d in dims ]
113126 ret = tuple (np .broadcast_to (v , shape ) for v in values )
@@ -117,7 +130,12 @@ def _broadcast_as_dims(*values, dims):
117130 return ret
118131
119132
120- def _psi_masked (positive_probs , positive_probs_std , * , dims ):
133+ def _psi_masked (
134+ positive_probs : pt .TensorLike ,
135+ positive_probs_std : pt .TensorLike ,
136+ * ,
137+ dims : Sequence [str ],
138+ ) -> Tuple [Union [pt .TensorLike , None ], pt .TensorVariable ]:
121139 if not (
122140 isinstance (positive_probs , pt .Constant ) and isinstance (positive_probs_std , pt .Constant )
123141 ):
@@ -152,7 +170,12 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
152170 return mask , psi
153171
154172
155- def _psi (positive_probs , positive_probs_std , * , dims ):
173+ def _psi (
174+ positive_probs : pt .TensorLike ,
175+ positive_probs_std : Union [pt .TensorLike , None ],
176+ * ,
177+ dims : Sequence [str ],
178+ ) -> Tuple [Union [pt .TensorLike , None ], pt .TensorVariable ]:
156179 if positive_probs_std is not None :
157180 mask , psi = _psi_masked (
158181 positive_probs = pt .as_tensor (positive_probs ),
@@ -171,12 +194,12 @@ def _psi(positive_probs, positive_probs_std, *, dims):
171194
172195
173196def _phi (
174- variables_importance ,
175- variance_explained ,
176- importance_concentration ,
197+ variables_importance : Union [ pt . TensorLike , None ] ,
198+ variance_explained : Union [ pt . TensorLike , None ] ,
199+ importance_concentration : Union [ pt . TensorLike , None ] ,
177200 * ,
178- dims ,
179- ):
201+ dims : Sequence [ str ] ,
202+ ) -> pt . TensorVariable :
180203 * broadcast_dims , dim = dims
181204 model = pm .modelcontext (None )
182205 if variables_importance is not None :
@@ -200,47 +223,50 @@ def _phi(
200223 return phi
201224
202225
226+ R2D2M2CPOut = namedtuple ("R2D2M2CPOut" , ["eps" , "beta" ])
227+
228+
203229def R2D2M2CP (
204- name ,
205- output_sigma ,
206- input_sigma ,
230+ name : str ,
231+ output_sigma : pt . TensorLike ,
232+ input_sigma : pt . TensorLike ,
207233 * ,
208- dims ,
209- r2 ,
210- variables_importance = None ,
211- variance_explained = None ,
212- importance_concentration = None ,
213- r2_std = None ,
214- positive_probs = 0.5 ,
215- positive_probs_std = None ,
216- centered = False ,
217- ):
234+ dims : Sequence [ str ] ,
235+ r2 : pt . TensorLike ,
236+ variables_importance : Union [ pt . TensorLike , None ] = None ,
237+ variance_explained : Union [ pt . TensorLike , None ] = None ,
238+ importance_concentration : Union [ pt . TensorLike , None ] = None ,
239+ r2_std : Union [ pt . TensorLike , None ] = None ,
240+ positive_probs : Union [ pt . TensorLike , None ] = 0.5 ,
241+ positive_probs_std : Union [ pt . TensorLike , None ] = None ,
242+ centered : bool = False ,
243+ ) -> R2D2M2CPOut :
218244 """R2D2M2CP Prior.
219245
220246 Parameters
221247 ----------
222248 name : str
223249 Name for the distribution
224- output_sigma : tensor
250+ output_sigma : Tensor
225251 Output standard deviation
226- input_sigma : tensor
252+ input_sigma : Tensor
227253 Input standard deviation
228254 dims : Union[str, Sequence[str]]
229255 Dims for the distribution
230- r2 : tensor
256+ r2 : Tensor
231257 :math:`R^2` estimate
232- variables_importance : tensor , optional
258+ variables_importance : Tensor , optional
233259 Optional estimate for variables importance, positive, by default None
234- variance_explained : tensor , optional
260+ variance_explained : Tensor , optional
235261 Alternative estimate for variables importance which is point estimate of
236262 variance explained, should sum up to one, by default None
237- importance_concentration : tensor , optional
263+ importance_concentration : Tensor , optional
238264 Confidence around variance explained or variable importance estimate
239- r2_std : tensor , optional
265+ r2_std : Tensor , optional
240266 Optional uncertainty over :math:`R^2`, by default None
241- positive_probs : tensor , optional
267+ positive_probs : Tensor , optional
242268 Optional probability of variables contribution to be positive, by default 0.5
243- positive_probs_std : tensor , optional
269+ positive_probs_std : Tensor , optional
244270 Optional uncertainty over effect direction probability, by default None
245271 centered : bool, optional
246272 Centered or Non-Centered parametrization of the distribution, by default Non-Centered. Advised to check both
@@ -419,4 +445,4 @@ def R2D2M2CP(
419445 psi_mask = mask ,
420446 )
421447 resid_sigma = (1 - r2 ) ** 0.5 * output_sigma
422- return resid_sigma , beta
448+ return R2D2M2CPOut ( resid_sigma , beta )
0 commit comments