Skip to content

Commit 5035b44

Browse files
committed
Updated multi-chain in R to initialize different chains from different forests if multiple chains requested
1 parent 9ddc9ba commit 5035b44

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

stochtree/bart.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def sample(
184184
previous_model_json : str, optional
185185
JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`.
186186
previous_model_warmstart_sample_num : int, optional
187-
Sample number from `previous_model_json` that will be used to warmstart this BART sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`.
187+
Sample number from `previous_model_json` that will be used to warmstart this BART sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`. If `num_chains` in `general_params` is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is `None`, the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples at or before `previous_model_warmstart_sample_num` (that is, `num_chains > previous_model_warmstart_sample_num + 1`), a warning will be issued and every chain will be initialized from sample `previous_model_warmstart_sample_num`.
188188
189189
Returns
190190
-------
@@ -753,6 +753,7 @@ def sample(
753753

754754
# Check if previous model JSON is provided and parse it if so
755755
has_prev_model = previous_model_json is not None
756+
has_prev_model_index = previous_model_warmstart_sample_num is not None
756757
if has_prev_model:
757758
if num_gfr > 0:
758759
if num_mcmc == 0:
@@ -765,6 +766,27 @@ def sample(
765766
)
766767
previous_bart_model = BARTModel()
767768
previous_bart_model.from_json(previous_model_json)
769+
prev_num_samples = previous_bart_model.num_samples
770+
if not has_prev_model_index:
771+
previous_model_warmstart_sample_num = prev_num_samples
772+
warnings.warn(
773+
"`previous_model_warmstart_sample_num` was not provided alongside `previous_model_json`, so it will be set to the number of samples available in `previous_model_json`"
774+
)
775+
else:
776+
if previous_model_warmstart_sample_num < 0:
777+
raise ValueError(
778+
"`previous_model_warmstart_sample_num` must be a nonnegative integer"
779+
)
780+
if previous_model_warmstart_sample_num >= prev_num_samples:
781+
raise ValueError(
782+
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
783+
)
784+
previous_model_decrement = True
785+
if num_chains > previous_model_warmstart_sample_num + 1:
786+
warnings.warn(
787+
"The number of chains being sampled exceeds the number of previous model samples available from the requested position in `previous_model_json`. All chains will be initialized from the same sample."
788+
)
789+
previous_model_decrement = False
768790
previous_y_scale = previous_bart_model.y_std
769791
previous_model_num_samples = previous_bart_model.num_samples
770792
if previous_bart_model.sample_sigma2_global:
@@ -1423,11 +1445,12 @@ def sample(
14231445
rfx_model.reset(self.rfx_container, forest_ind, sigma_alpha_init)
14241446
rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container)
14251447
elif has_prev_model:
1448+
warmstart_index = previous_model_warmstart_sample_num - chain_num if previous_model_decrement else previous_model_warmstart_sample_num
14261449
# Reset mean forest
14271450
if self.include_mean_forest:
14281451
active_forest_mean.reset(
14291452
previous_bart_model.forest_container_mean,
1430-
previous_model_warmstart_sample_num,
1453+
warmstart_index,
14311454
)
14321455
forest_sampler_mean.reconstitute_from_forest(
14331456
active_forest_mean,
@@ -1438,7 +1461,7 @@ def sample(
14381461
# Reset leaf scale
14391462
if sample_sigma2_leaf and previous_leaf_var_samples is not None:
14401463
leaf_scale_double = previous_leaf_var_samples[
1441-
previous_model_warmstart_sample_num
1464+
warmstart_index
14421465
]
14431466
current_leaf_scale[0, 0] = leaf_scale_double
14441467
forest_model_config_mean.update_leaf_model_scale(
@@ -1448,7 +1471,7 @@ def sample(
14481471
if self.include_variance_forest:
14491472
active_forest_variance.reset(
14501473
previous_bart_model.forest_container_variance,
1451-
previous_model_warmstart_sample_num,
1474+
warmstart_index,
14521475
)
14531476
forest_sampler_variance.reconstitute_from_forest(
14541477
active_forest_variance,
@@ -1459,12 +1482,12 @@ def sample(
14591482
# Reset global error scale
14601483
if self.sample_sigma2_global:
14611484
current_sigma2 = previous_global_var_samples[
1462-
previous_model_warmstart_sample_num
1485+
warmstart_index
14631486
]
14641487
global_model_config.update_global_error_variance(current_sigma2)
14651488
# Reset random effects
14661489
if self.has_rfx:
1467-
rfx_model.reset(previous_bart_model.rfx_container, forest_ind, sigma_alpha_init)
1490+
rfx_model.reset(previous_bart_model.rfx_container, warmstart_index, sigma_alpha_init)
14681491
rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, previous_bart_model.rfx_container)
14691492
else:
14701493
# Reset mean forest

stochtree/bcf.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def sample(
226226
previous_model_json : str, optional
227227
JSON string containing a previous BCF model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`.
228228
previous_model_warmstart_sample_num : int, optional
229-
Sample number from `previous_model_json` that will be used to warmstart this BART sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`.
229+
Sample number from `previous_model_json` that will be used to warmstart this BCF sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`. If `num_chains` in `general_params` is > 1, then each successive chain will be initialized from a different sample, counting backwards from `previous_model_warmstart_sample_num`. That is, if `previous_model_warmstart_sample_num = 10` and `num_chains = 4`, then chain 1 will be initialized from sample 10, chain 2 from sample 9, chain 3 from sample 8, and chain 4 from sample 7. If `previous_model_json` is provided but `previous_model_warmstart_sample_num` is `None`, the last sample in the previous model will be used to initialize the first chain, counting backwards as noted before. If more chains are requested than there are samples at or before `previous_model_warmstart_sample_num` (that is, `num_chains > previous_model_warmstart_sample_num + 1`), a warning will be issued and every chain will be initialized from sample `previous_model_warmstart_sample_num`.
230230
231231
Returns
232232
-------
@@ -462,6 +462,7 @@ def sample(
462462

463463
# Check if previous model JSON is provided and parse it if so
464464
has_prev_model = previous_model_json is not None
465+
has_prev_model_index = previous_model_warmstart_sample_num is not None
465466
if has_prev_model:
466467
if num_gfr > 0:
467468
if num_mcmc == 0:
@@ -474,6 +475,27 @@ def sample(
474475
)
475476
previous_bcf_model = BCFModel()
476477
previous_bcf_model.from_json(previous_model_json)
478+
prev_num_samples = previous_bcf_model.num_samples
479+
if not has_prev_model_index:
480+
previous_model_warmstart_sample_num = prev_num_samples
481+
warnings.warn(
482+
"`previous_model_warmstart_sample_num` was not provided alongside `previous_model_json`, so it will be set to the number of samples available in `previous_model_json`"
483+
)
484+
else:
485+
if previous_model_warmstart_sample_num < 0:
486+
raise ValueError(
487+
"`previous_model_warmstart_sample_num` must be a nonnegative integer"
488+
)
489+
if previous_model_warmstart_sample_num >= prev_num_samples:
490+
raise ValueError(
491+
"`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`"
492+
)
493+
previous_model_decrement = True
494+
if num_chains > previous_model_warmstart_sample_num + 1:
495+
warnings.warn(
496+
"The number of chains being sampled exceeds the number of previous model samples available from the requested position in `previous_model_json`. All chains will be initialized from the same sample."
497+
)
498+
previous_model_decrement = False
477499
previous_y_scale = previous_bcf_model.y_std
478500
previous_model_num_samples = previous_bcf_model.num_samples
479501
if previous_bcf_model.sample_sigma2_global:
@@ -2171,8 +2193,9 @@ def sample(
21712193
rfx_model.reset(self.rfx_container, forest_ind, sigma_alpha_init)
21722194
rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, self.rfx_container)
21732195
elif has_prev_model:
2196+
warmstart_index = previous_model_warmstart_sample_num - chain_num if previous_model_decrement else previous_model_warmstart_sample_num
21742197
# Reset prognostic forest
2175-
active_forest_mu.reset(previous_bcf_model.forest_container_mu, previous_model_warmstart_sample_num)
2198+
active_forest_mu.reset(previous_bcf_model.forest_container_mu, warmstart_index)
21762199
forest_sampler_mu.reconstitute_from_forest(
21772200
active_forest_mu,
21782201
forest_dataset_train,
@@ -2191,7 +2214,7 @@ def sample(
21912214
if self.include_variance_forest:
21922215
active_forest_variance.reset(
21932216
previous_bcf_model.forest_container_variance,
2194-
previous_model_warmstart_sample_num,
2217+
warmstart_index,
21952218
)
21962219
forest_sampler_variance.reconstitute_from_forest(
21972220
active_forest_variance,
@@ -2202,13 +2225,13 @@ def sample(
22022225
# Reset global error scale
22032226
if self.sample_sigma2_global:
22042227
current_sigma2 = previous_global_var_samples[
2205-
previous_model_warmstart_sample_num
2228+
warmstart_index
22062229
]
22072230
global_model_config.update_global_error_variance(current_sigma2)
22082231
# Reset mu forest leaf scale
22092232
if sample_sigma2_leaf_mu and previous_leaf_var_mu_samples is not None:
22102233
leaf_scale_double_mu = previous_leaf_var_mu_samples[
2211-
previous_model_warmstart_sample_num
2234+
warmstart_index
22122235
]
22132236
current_leaf_scale_mu[0, 0] = leaf_scale_double_mu
22142237
forest_model_config_mu.update_leaf_model_scale(
@@ -2217,7 +2240,7 @@ def sample(
22172240
# Reset tau forest leaf scale
22182241
if sample_sigma2_leaf_tau and previous_leaf_var_tau_samples is not None:
22192242
leaf_scale_double_tau = previous_leaf_var_tau_samples[
2220-
previous_model_warmstart_sample_num
2243+
warmstart_index
22212244
]
22222245
current_leaf_scale_tau[0, 0] = leaf_scale_double_tau
22232246
forest_model_config_tau.update_leaf_model_scale(
@@ -2226,9 +2249,9 @@ def sample(
22262249
# Reset adaptive coding parameters
22272250
if self.adaptive_coding:
22282251
if previous_b0_samples is not None:
2229-
current_b_0 = previous_b0_samples[previous_model_warmstart_sample_num]
2252+
current_b_0 = previous_b0_samples[warmstart_index]
22302253
if previous_b1_samples is not None:
2231-
current_b_1 = previous_b1_samples[previous_model_warmstart_sample_num]
2254+
current_b_1 = previous_b1_samples[warmstart_index]
22322255
tau_basis_train = (
22332256
1 - np.squeeze(Z_train)
22342257
) * current_b_0 + np.squeeze(Z_train) * current_b_1
@@ -2243,7 +2266,7 @@ def sample(
22432266
)
22442267
# Reset random effects terms
22452268
if self.has_rfx:
2246-
rfx_model.reset(previous_bcf_model.rfx_container, forest_ind, sigma_alpha_init)
2269+
rfx_model.reset(previous_bcf_model.rfx_container, warmstart_index, sigma_alpha_init)
22472270
rfx_tracker.reset(rfx_model, rfx_dataset_train, residual_train, previous_bcf_model.rfx_container)
22482271
else:
22492272
# Reset prognostic forest

0 commit comments

Comments
 (0)