We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
pred_var
1 parent 14d619e commit 2783e09Copy full SHA for 2783e09
econml/grf/_base_grf.py
@@ -795,8 +795,13 @@ def predict_full(self, X, interval=False, alpha=0.05):
795
point, pred_var = self._predict_point_and_var(X, full=True, point=True, var=True)
796
lb, ub = np.zeros(point.shape), np.zeros(point.shape)
797
for t in range(self.n_outputs_):
798
- lb[:, t] = scipy.stats.norm.ppf(alpha / 2, loc=point[:, t], scale=np.sqrt(pred_var[:, t, t]))
799
- ub[:, t] = scipy.stats.norm.ppf(1 - alpha / 2, loc=point[:, t], scale=np.sqrt(pred_var[:, t, t]))
+ var = pred_var[:, t, t]
+ assert np.isclose(var[var < 0], 0, atol=1e-8).all(), f'`pred_var` must be > 0 {var[var < 0]}'
800
+ var = np.maximum(var, 1e-32)
801
+
802
+ pred_dist = scipy.stats.norm(loc=point[:, t], scale=np.sqrt(var))
803
+ lb[:, t] = pred_dist.ppf(alpha / 2)
804
+ ub[:, t] = pred_dist.ppf(1 - (alpha / 2))
805
return point, lb, ub
806
return self._predict_point_and_var(X, full=True, point=True, var=False)
807
0 commit comments