Skip to content

Commit 2878064

Browse files
authored
Merge pull request #239 from StochasticTree/python-arg-order-hotfix
Update argument order in Python BART and BCF
2 parents 96f78f0 + 2d6b08e commit 2878064

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

stochtree/bart.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,25 +2193,19 @@ def compute_contrast(
21932193

21942194
def compute_posterior_interval(
21952195
self,
2196-
terms: Union[list[str], str] = "all",
2197-
level: float = 0.95,
2198-
scale: str = "linear",
21992196
X: np.array = None,
22002197
leaf_basis: np.array = None,
22012198
rfx_group_ids: np.array = None,
22022199
rfx_basis: np.array = None,
2200+
terms: Union[list[str], str] = "all",
2201+
level: float = 0.95,
2202+
scale: str = "linear",
22032203
) -> dict:
22042204
"""
22052205
Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
22062206
22072207
Parameters
22082208
----------
2209-
terms : str, optional
2210-
Character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Defaults to `"all"`.
2211-
scale : str, optional
2212-
Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`.
2213-
level : float, optional
2214-
A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval.
22152209
X : np.array, optional
22162210
Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., mean forest, variance forest, or overall predictions).
22172211
leaf_basis : np.array, optional
@@ -2220,6 +2214,12 @@ def compute_posterior_interval(
22202214
Optional vector of group IDs for random effects. Required if the requested term includes random effects.
22212215
rfx_basis : np.array, optional
22222216
Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
2217+
terms : str, optional
2218+
Character string specifying the model term(s) for which to compute intervals. Options for BART models are `"mean_forest"`, `"variance_forest"`, `"rfx"`, `"y_hat"`, or `"all"`. Defaults to `"all"`.
2219+
scale : str, optional
2220+
Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`.
2221+
level : float, optional
2222+
A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval.
22232223
22242224
Returns
22252225
-------

stochtree/bcf.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3260,26 +3260,20 @@ def compute_contrast(
32603260

32613261
def compute_posterior_interval(
32623262
self,
3263-
terms: Union[list[str], str] = "all",
3264-
level: float = 0.95,
3265-
scale: str = "linear",
32663263
X: np.array = None,
32673264
Z: np.array = None,
32683265
propensity: np.array = None,
32693266
rfx_group_ids: np.array = None,
32703267
rfx_basis: np.array = None,
3268+
terms: Union[list[str], str] = "all",
3269+
level: float = 0.95,
3270+
scale: str = "linear",
32713271
) -> dict:
32723272
"""
32733273
Compute posterior credible intervals for specified terms from a fitted BART model. It supports intervals for mean functions, variance functions, random effects, and overall predictions.
32743274
32753275
Parameters
32763276
----------
3277-
terms : str, optional
3278-
Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`.
3279-
scale : str, optional
3280-
Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`.
3281-
level : float, optional
3282-
A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval.
32833277
X : np.array, optional
32843278
Optional array or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, treatment effect forest, variance forest, or overall predictions).
32853279
Z : np.array, optional
@@ -3290,6 +3284,12 @@ def compute_posterior_interval(
32903284
Optional vector of group IDs for random effects. Required if the requested term includes random effects.
32913285
rfx_basis : np.array, optional
32923286
Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects.
3287+
terms : str, optional
3288+
Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`.
3289+
scale : str, optional
3290+
Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`.
3291+
level : float, optional
3292+
A numeric value between 0 and 1 specifying the credible interval level. Defaults to 0.95 for a 95% credible interval.
32933293
32943294
Returns
32953295
-------

0 commit comments

Comments
 (0)