Skip to content

Commit 2783e09

Browse files
authored
Enforce pred_var is always greater than zero on GRF (#480)
1 parent 14d619e commit 2783e09

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

econml/grf/_base_grf.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,8 +795,13 @@ def predict_full(self, X, interval=False, alpha=0.05):
795795
point, pred_var = self._predict_point_and_var(X, full=True, point=True, var=True)
796796
lb, ub = np.zeros(point.shape), np.zeros(point.shape)
797797
for t in range(self.n_outputs_):
798-
lb[:, t] = scipy.stats.norm.ppf(alpha / 2, loc=point[:, t], scale=np.sqrt(pred_var[:, t, t]))
799-
ub[:, t] = scipy.stats.norm.ppf(1 - alpha / 2, loc=point[:, t], scale=np.sqrt(pred_var[:, t, t]))
798+
var = pred_var[:, t, t]
799+
assert np.isclose(var[var < 0], 0, atol=1e-8).all(), f'`pred_var` must be > 0 {var[var < 0]}'
800+
var = np.maximum(var, 1e-32)
801+
802+
pred_dist = scipy.stats.norm(loc=point[:, t], scale=np.sqrt(var))
803+
lb[:, t] = pred_dist.ppf(alpha / 2)
804+
ub[:, t] = pred_dist.ppf(1 - (alpha / 2))
800805
return point, lb, ub
801806
return self._predict_point_and_var(X, full=True, point=True, var=False)
802807

0 commit comments

Comments
 (0)