Skip to content

Commit 92eb863

Browse files
authored
Merge pull request #229 from StochasticTree/multi-chain-doc-upgrade
Updating docs and vignettes regarding multi-chain outputs
2 parents 5242d0b + c606754 commit 92eb863

File tree

13 files changed

+2586
-492
lines changed

13 files changed

+2586
-492
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ Suggests:
4343
MASS,
4444
mvtnorm,
4545
rmarkdown,
46-
tgp
46+
tgp,
47+
coda,
48+
bayesplot
4749
VignetteBuilder: knitr
4850
SystemRequirements: C++17
4951
Imports:

R/bart.R

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
2929
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
3030
#' @param previous_model_json (Optional) JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`.
31-
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`.
31+
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BART sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`. If `num_chains` in the `general_params` list is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is `NULL`, a warning will be raised and the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples available when counting backwards from `previous_model_warmstart_sample_num` (i.e. `num_chains > previous_model_warmstart_sample_num`), a warning will be raised and every chain will be initialized from sample `previous_model_warmstart_sample_num`.
3232
#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
3333
#'
3434
#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`.
@@ -42,7 +42,7 @@
4242
#' - `keep_burnin` Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
4343
#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
4444
#' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
45-
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
45+
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. Note that if `num_chains > 1`, the returned model object will contain samples from all chains, stored consecutively. That is, if there are 4 chains with 100 samples each, the first 100 samples will be from chain 1, the next 100 samples will be from chain 2, and so on. For more detail on working with multi-chain BART models, see the multi-chain vignette: \code{vignette("Multiple-Chains", package = "stochtree")}.
4646
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
4747
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
4848
#' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads.
@@ -293,10 +293,36 @@ bart <- function(
293293

294294
# Check if previous model JSON is provided and parse it if so
295295
has_prev_model <- !is.null(previous_model_json)
296+
has_prev_model_index <- !is.null(previous_model_warmstart_sample_num)
296297
if (has_prev_model) {
297298
previous_bart_model <- createBARTModelFromJsonString(
298299
previous_model_json
299300
)
301+
prev_num_samples <- previous_bart_model$model_params$num_samples
302+
if (!has_prev_model_index) {
303+
previous_model_warmstart_sample_num <- prev_num_samples
304+
warning(
305+
"`previous_model_warmstart_sample_num` was not provided alongside `previous_model_json`, so it will be set to the number of samples available in `previous_model_json`"
306+
)
307+
} else {
308+
if (previous_model_warmstart_sample_num < 1) {
309+
stop(
310+
"`previous_model_warmstart_sample_num` must be a positive integer"
311+
)
312+
}
313+
if (previous_model_warmstart_sample_num > prev_num_samples) {
314+
stop(
315+
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
316+
)
317+
}
318+
}
319+
previous_model_decrement <- T
320+
if (num_chains > previous_model_warmstart_sample_num) {
321+
warning(
322+
"The number of chains being sampled exceeds the number of previous model samples available from the requested position in `previous_model_json`. All chains will be initialized from the same sample."
323+
)
324+
previous_model_decrement <- F
325+
}
300326
previous_y_bar <- previous_bart_model$model_params$outcome_mean
301327
previous_y_scale <- previous_bart_model$model_params$outcome_scale
302328
if (previous_bart_model$model_params$include_mean_forest) {
@@ -1375,11 +1401,16 @@ bart <- function(
13751401
)
13761402
}
13771403
} else if (has_prev_model) {
1404+
warmstart_index <- ifelse(
1405+
previous_model_decrement,
1406+
previous_model_warmstart_sample_num - chain_num + 1,
1407+
previous_model_warmstart_sample_num
1408+
)
13781409
if (include_mean_forest) {
13791410
resetActiveForest(
13801411
active_forest_mean,
13811412
previous_forest_samples_mean,
1382-
previous_model_warmstart_sample_num - 1
1413+
warmstart_index - 1
13831414
)
13841415
resetForestModel(
13851416
forest_model_mean,
@@ -1393,7 +1424,7 @@ bart <- function(
13931424
(!is.null(previous_leaf_var_samples))
13941425
) {
13951426
leaf_scale_double <- previous_leaf_var_samples[
1396-
previous_model_warmstart_sample_num
1427+
warmstart_index
13971428
]
13981429
current_leaf_scale <- as.matrix(leaf_scale_double)
13991430
forest_model_config_mean$update_leaf_model_scale(
@@ -1405,7 +1436,7 @@ bart <- function(
14051436
resetActiveForest(
14061437
active_forest_variance,
14071438
previous_forest_samples_variance,
1408-
previous_model_warmstart_sample_num - 1
1439+
warmstart_index - 1
14091440
)
14101441
resetForestModel(
14111442
forest_model_variance,
@@ -1439,7 +1470,7 @@ bart <- function(
14391470
resetRandomEffectsModel(
14401471
rfx_model,
14411472
previous_rfx_samples,
1442-
previous_model_warmstart_sample_num - 1,
1473+
warmstart_index - 1,
14431474
sigma_alpha_init
14441475
)
14451476
resetRandomEffectsTracker(
@@ -1454,7 +1485,7 @@ bart <- function(
14541485
if (sample_sigma2_global) {
14551486
if (!is.null(previous_global_var_samples)) {
14561487
current_sigma2 <- previous_global_var_samples[
1457-
previous_model_warmstart_sample_num
1488+
warmstart_index
14581489
]
14591490
global_model_config$update_global_error_variance(
14601491
current_sigma2

R/bcf.R

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
#' @param num_burnin Number of "burn-in" iterations of the MCMC sampler. Default: 0.
2626
#' @param num_mcmc Number of "retained" iterations of the MCMC sampler. Default: 100.
2727
#' @param previous_model_json (Optional) JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Default: `NULL`.
28-
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`.
28+
#' @param previous_model_warmstart_sample_num (Optional) Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. One-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 1`). Default: `NULL`. If `num_chains` in the `general_params` list is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is `NULL`, a warning will be raised and the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples available when counting backwards from `previous_model_warmstart_sample_num` (i.e. `num_chains > previous_model_warmstart_sample_num`), a warning will be raised and every chain will be initialized from sample `previous_model_warmstart_sample_num`.
2929
#' @param general_params (Optional) A list of general (non-forest-specific) model parameters, each of which has a default value processed internally, so this argument list is optional.
3030
#'
3131
#' - `cutpoint_grid_size` Maximum size of the "grid" of potential cutpoints to consider in the GFR algorithm. Default: `100`.
@@ -44,7 +44,7 @@
4444
#' - `keep_burnin` Whether or not "burnin" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
4545
#' - `keep_gfr` Whether or not "grow-from-root" samples should be included in the stored samples of forests and other parameters. Default `FALSE`. Ignored if `num_mcmc = 0`.
4646
#' - `keep_every` How many iterations of the burned-in MCMC sampler should be run before forests and parameters are retained. Default `1`. Setting `keep_every <- k` for some `k > 1` will "thin" the MCMC samples by retaining every `k`-th sample, rather than simply every sample. This can reduce the autocorrelation of the MCMC samples.
47-
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`.
47+
#' - `num_chains` How many independent MCMC chains should be sampled. If `num_mcmc = 0`, this is ignored. If `num_gfr = 0`, then each chain is run from root for `num_mcmc * keep_every + num_burnin` iterations, with `num_mcmc` samples retained. If `num_gfr > 0`, each MCMC chain will be initialized from a separate GFR ensemble, with the requirement that `num_gfr >= num_chains`. Default: `1`. Note that if `num_chains > 1`, the returned model object will contain samples from all chains, stored consecutively. That is, if there are 4 chains with 100 samples each, the first 100 samples will be from chain 1, the next 100 samples will be from chain 2, and so on. For more detail on working with multi-chain BCF models, see the multi-chain vignette: \code{vignette("Multiple-Chains", package = "stochtree")}.
4848
#' - `verbose` Whether or not to print progress during the sampling loops. Default: `FALSE`.
4949
#' - `probit_outcome_model` Whether or not the outcome should be modeled as explicitly binary via a probit link. If `TRUE`, `y` must only contain the values `0` and `1`. Default: `FALSE`.
5050
#' - `num_threads` Number of threads to use in the GFR and MCMC algorithms, as well as prediction. If OpenMP is not available on a user's setup, this will default to `1`, otherwise to the maximum number of available threads.
@@ -397,8 +397,34 @@ bcf <- function(
397397

398398
# Check if previous model JSON is provided and parse it if so
399399
has_prev_model <- !is.null(previous_model_json)
400+
has_prev_model_index <- !is.null(previous_model_warmstart_sample_num)
400401
if (has_prev_model) {
401402
previous_bcf_model <- createBCFModelFromJsonString(previous_model_json)
403+
prev_num_samples <- previous_bcf_model$model_params$num_samples
404+
if (!has_prev_model_index) {
405+
previous_model_warmstart_sample_num <- prev_num_samples
406+
warning(
407+
"`previous_model_warmstart_sample_num` was not provided alongside `previous_model_json`, so it will be set to the number of samples available in `previous_model_json`"
408+
)
409+
} else {
410+
if (previous_model_warmstart_sample_num < 1) {
411+
stop(
412+
"`previous_model_warmstart_sample_num` must be a positive integer"
413+
)
414+
}
415+
if (previous_model_warmstart_sample_num > prev_num_samples) {
416+
stop(
417+
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
418+
)
419+
}
420+
}
421+
previous_model_decrement <- T
422+
if (num_chains > previous_model_warmstart_sample_num) {
423+
warning(
424+
"The number of chains being sampled exceeds the number of previous model samples available from the requested position in `previous_model_json`. All chains will be initialized from the same sample."
425+
)
426+
previous_model_decrement <- F
427+
}
402428
previous_y_bar <- previous_bcf_model$model_params$outcome_mean
403429
previous_y_scale <- previous_bcf_model$model_params$outcome_scale
404430
previous_forest_samples_mu <- previous_bcf_model$forests_mu
@@ -1910,10 +1936,15 @@ bcf <- function(
19101936
)
19111937
}
19121938
} else if (has_prev_model) {
1939+
warmstart_index <- ifelse(
1940+
previous_model_decrement,
1941+
previous_model_warmstart_sample_num - chain_num + 1,
1942+
previous_model_warmstart_sample_num
1943+
)
19131944
resetActiveForest(
19141945
active_forest_mu,
19151946
previous_forest_samples_mu,
1916-
previous_model_warmstart_sample_num - 1
1947+
warmstart_index - 1
19171948
)
19181949
resetForestModel(
19191950
forest_model_mu,
@@ -1925,7 +1956,7 @@ bcf <- function(
19251956
resetActiveForest(
19261957
active_forest_tau,
19271958
previous_forest_samples_tau,
1928-
previous_model_warmstart_sample_num - 1
1959+
warmstart_index - 1
19291960
)
19301961
resetForestModel(
19311962
forest_model_tau,
@@ -1938,7 +1969,7 @@ bcf <- function(
19381969
resetActiveForest(
19391970
active_forest_variance,
19401971
previous_forest_samples_variance,
1941-
previous_model_warmstart_sample_num - 1
1972+
warmstart_index - 1
19421973
)
19431974
resetForestModel(
19441975
forest_model_variance,
@@ -1953,7 +1984,7 @@ bcf <- function(
19531984
(!is.null(previous_leaf_var_mu_samples))
19541985
) {
19551986
leaf_scale_mu_double <- previous_leaf_var_mu_samples[
1956-
previous_model_warmstart_sample_num
1987+
warmstart_index
19571988
]
19581989
current_leaf_scale_mu <- as.matrix(leaf_scale_mu_double)
19591990
forest_model_config_mu$update_leaf_model_scale(
@@ -1965,7 +1996,7 @@ bcf <- function(
19651996
(!is.null(previous_leaf_var_tau_samples))
19661997
) {
19671998
leaf_scale_tau_double <- previous_leaf_var_tau_samples[
1968-
previous_model_warmstart_sample_num
1999+
warmstart_index
19692000
]
19702001
current_leaf_scale_tau <- as.matrix(leaf_scale_tau_double)
19712002
forest_model_config_tau$update_leaf_model_scale(
@@ -1975,12 +2006,12 @@ bcf <- function(
19752006
if (adaptive_coding) {
19762007
if (!is.null(previous_b_1_samples)) {
19772008
current_b_1 <- previous_b_1_samples[
1978-
previous_model_warmstart_sample_num
2009+
warmstart_index
19792010
]
19802011
}
19812012
if (!is.null(previous_b_0_samples)) {
19822013
current_b_0 <- previous_b_0_samples[
1983-
previous_model_warmstart_sample_num
2014+
warmstart_index
19842015
]
19852016
}
19862017
tau_basis_train <- (1 - Z_train) *
@@ -2023,7 +2054,7 @@ bcf <- function(
20232054
resetRandomEffectsModel(
20242055
rfx_model,
20252056
previous_rfx_samples,
2026-
previous_model_warmstart_sample_num - 1,
2057+
warmstart_index - 1,
20272058
sigma_alpha_init
20282059
)
20292060
resetRandomEffectsTracker(
@@ -2038,7 +2069,7 @@ bcf <- function(
20382069
if (sample_sigma2_global) {
20392070
if (!is.null(previous_global_var_samples)) {
20402071
current_sigma2 <- previous_global_var_samples[
2041-
previous_model_warmstart_sample_num
2072+
warmstart_index
20422073
]
20432074
}
20442075
global_model_config$update_global_error_variance(

0 commit comments

Comments
 (0)