Skip to content

Commit c606754

Browse files
committed
Updated multi-chain demos
1 parent 5035b44 commit c606754

File tree

5 files changed

+1312
-312
lines changed

5 files changed

+1312
-312
lines changed

demo/debug/multi_chain.py

Lines changed: 30 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
# Multi Chain Demo Script
1+
# Multiple Initializations Demo Script
22

33
# Load necessary libraries
44
import matplotlib.pyplot as plt
55
import numpy as np
66
import pandas as pd
7-
import seaborn as sns
7+
import arviz as az
88
from sklearn.model_selection import train_test_split
99

1010
from stochtree import BARTModel
@@ -37,7 +37,9 @@ def outcome_mean(X, W):
3737

3838
# Generate outcome
3939
f_XW = outcome_mean(X, W)
40-
epsilon = rng.normal(0, 1, n)
40+
snr = 3
41+
noise_sd = np.std(f_XW) / snr
42+
epsilon = rng.normal(0, noise_sd, n)
4143
y = f_XW + epsilon
4244

4345
# Test-train split
@@ -53,116 +55,51 @@ def outcome_mean(X, W):
5355
y_test = y[test_inds]
5456

5557
# Run the GFR algorithm for a small number of iterations
56-
general_model_params = {"random_seed": -1}
57-
mean_forest_model_params = {"num_trees": 20}
5858
num_warmstart = 10
59-
num_mcmc = 10
60-
bart_model = BARTModel()
61-
bart_model.sample(
59+
xbart_model = BARTModel()
60+
xbart_model.sample(
6261
X_train=X_train,
6362
y_train=y_train,
6463
leaf_basis_train=basis_train,
6564
X_test=X_test,
6665
leaf_basis_test=basis_test,
6766
num_gfr=num_warmstart,
6867
num_mcmc=0,
69-
general_params=general_model_params,
70-
mean_forest_params=mean_forest_model_params,
7168
)
72-
bart_model_json = bart_model.to_json()
69+
xbart_model_json = xbart_model.to_json()
7370

74-
# Run several BART MCMC samples from the last GFR forest
75-
bart_model_2 = BARTModel()
76-
bart_model_2.sample(
71+
# Run several BART MCMC chains from the last GFR forest
72+
num_mcmc = 5000
73+
num_burnin = 2000
74+
num_chains = 4
75+
bart_model = BARTModel()
76+
bart_model.sample(
7777
X_train=X_train,
7878
y_train=y_train,
7979
leaf_basis_train=basis_train,
8080
X_test=X_test,
8181
leaf_basis_test=basis_test,
8282
num_gfr=0,
83+
num_burnin=num_burnin,
8384
num_mcmc=num_mcmc,
84-
previous_model_json=bart_model_json,
85+
previous_model_json=xbart_model_json,
8586
previous_model_warmstart_sample_num=num_warmstart - 1,
86-
general_params=general_model_params,
87-
mean_forest_params=mean_forest_model_params,
88-
)
89-
90-
# Run several BART MCMC samples from the second-to-last GFR forest
91-
bart_model_3 = BARTModel()
92-
bart_model_3.sample(
93-
X_train=X_train,
94-
y_train=y_train,
95-
leaf_basis_train=basis_train,
96-
X_test=X_test,
97-
leaf_basis_test=basis_test,
98-
num_gfr=0,
99-
num_mcmc=num_mcmc,
100-
previous_model_json=bart_model_json,
101-
previous_model_warmstart_sample_num=num_warmstart - 2,
102-
general_params=general_model_params,
103-
mean_forest_params=mean_forest_model_params,
104-
)
105-
106-
# Run several BART MCMC samples from root
107-
bart_model_4 = BARTModel()
108-
bart_model_4.sample(
109-
X_train=X_train,
110-
y_train=y_train,
111-
leaf_basis_train=basis_train,
112-
X_test=X_test,
113-
leaf_basis_test=basis_test,
114-
num_gfr=0,
115-
num_mcmc=num_mcmc,
116-
general_params=general_model_params,
117-
mean_forest_params=mean_forest_model_params,
87+
general_params={"num_chains": num_chains}
11888
)
11989

120-
# Inspect the model outputs
121-
bart_preds_2 = bart_model_2.predict(X_test, basis_test)
122-
y_hat_mcmc_2 = bart_preds_2['y_hat']
123-
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
124-
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
125-
bart_preds_3 = bart_model_3.predict(X_test, basis_test)
126-
y_hat_mcmc_3 = bart_preds_3['y_hat']
127-
y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True)
128-
bart_preds_4 = bart_model_4.predict(X_test, basis_test)
129-
y_hat_mcmc_4 = bart_preds_4['y_hat']
130-
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
131-
y_df = pd.DataFrame(
132-
np.concatenate(
133-
(y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)),
134-
axis=1,
135-
),
136-
columns=["First Chain", "Second Chain", "Third Chain", "Outcome"],
90+
# Analyze model predictions collectively across all chains
91+
y_hat_test = bart_model.predict(
92+
covariates = X_test,
93+
basis = basis_test,
94+
type = "mean",
95+
terms = "y_hat"
13796
)
138-
139-
# Compare first warm-start chain to root chain with equal number of MCMC draws
140-
sns.scatterplot(data=y_df, x="First Chain", y="Third Chain")
141-
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
142-
plt.show()
143-
144-
# Compare first warm-start chain to outcome
145-
sns.scatterplot(data=y_df, x="First Chain", y="Outcome")
146-
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
147-
plt.show()
148-
149-
# Compare root chain to outcome
150-
sns.scatterplot(data=y_df, x="Third Chain", y="Outcome")
97+
plt.scatter(y_hat_test, y_test)
98+
plt.xlabel("Estimated conditional mean")
99+
plt.ylabel("Actual outcome")
151100
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
152-
plt.show()
153101

154-
# Compute RMSEs
155-
rmse_1 = np.sqrt(
156-
np.mean((np.squeeze(y_avg_mcmc_2) - y_test) * (np.squeeze(y_avg_mcmc_2) - y_test))
157-
)
158-
rmse_2 = np.sqrt(
159-
np.mean((np.squeeze(y_avg_mcmc_3) - y_test) * (np.squeeze(y_avg_mcmc_3) - y_test))
160-
)
161-
rmse_3 = np.sqrt(
162-
np.mean((np.squeeze(y_avg_mcmc_4) - y_test) * (np.squeeze(y_avg_mcmc_4) - y_test))
163-
)
164-
print(
165-
"Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(
166-
rmse_1, rmse_2, rmse_3
167-
)
168-
)
102+
# Analyze each chain's parameter samples
103+
sigma2_samples = bart_model.global_var_samples
104+
sigma2_samples_by_chain = {"sigma2": np.reshape(sigma2_samples, (num_chains, num_mcmc))}
105+
az.plot_trace(sigma2_samples_by_chain)
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
# Multiple Initializations Demo Script
2+
3+
# Load necessary libraries
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
import pandas as pd
7+
import seaborn as sns
8+
from sklearn.model_selection import train_test_split
9+
10+
from stochtree import BARTModel
11+
12+
# Generate sample data
13+
# RNG
14+
random_seed = 1234
15+
rng = np.random.default_rng(random_seed)
16+
17+
# Generate covariates and basis
18+
n = 500
19+
p_X = 10
20+
p_W = 1
21+
X = rng.uniform(0, 1, (n, p_X))
22+
W = rng.uniform(0, 1, (n, p_W))
23+
24+
25+
# Define the outcome mean function
26+
def outcome_mean(X, W):
27+
return np.where(
28+
(X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
29+
-7.5 * W[:, 0],
30+
np.where(
31+
(X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
32+
-2.5 * W[:, 0],
33+
np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]),
34+
),
35+
)
36+
37+
38+
# Generate outcome
39+
f_XW = outcome_mean(X, W)
40+
epsilon = rng.normal(0, 1, n)
41+
y = f_XW + epsilon
42+
43+
# Test-train split
44+
sample_inds = np.arange(n)
45+
train_inds, test_inds = train_test_split(
46+
sample_inds, test_size=0.5, random_state=random_seed
47+
)
48+
X_train = X[train_inds, :]
49+
X_test = X[test_inds, :]
50+
basis_train = W[train_inds, :]
51+
basis_test = W[test_inds, :]
52+
y_train = y[train_inds]
53+
y_test = y[test_inds]
54+
55+
# Run the GFR algorithm for a small number of iterations
56+
general_model_params = {"random_seed": -1}
57+
mean_forest_model_params = {"num_trees": 20}
58+
num_warmstart = 10
59+
num_mcmc = 10
60+
bart_model = BARTModel()
61+
bart_model.sample(
62+
X_train=X_train,
63+
y_train=y_train,
64+
leaf_basis_train=basis_train,
65+
X_test=X_test,
66+
leaf_basis_test=basis_test,
67+
num_gfr=num_warmstart,
68+
num_mcmc=0,
69+
general_params=general_model_params,
70+
mean_forest_params=mean_forest_model_params,
71+
)
72+
bart_model_json = bart_model.to_json()
73+
74+
# Run several BART MCMC samples from the last GFR forest
75+
bart_model_2 = BARTModel()
76+
bart_model_2.sample(
77+
X_train=X_train,
78+
y_train=y_train,
79+
leaf_basis_train=basis_train,
80+
X_test=X_test,
81+
leaf_basis_test=basis_test,
82+
num_gfr=0,
83+
num_mcmc=num_mcmc,
84+
previous_model_json=bart_model_json,
85+
previous_model_warmstart_sample_num=num_warmstart - 1,
86+
general_params=general_model_params,
87+
mean_forest_params=mean_forest_model_params,
88+
)
89+
90+
# Run several BART MCMC samples from the second-to-last GFR forest
91+
bart_model_3 = BARTModel()
92+
bart_model_3.sample(
93+
X_train=X_train,
94+
y_train=y_train,
95+
leaf_basis_train=basis_train,
96+
X_test=X_test,
97+
leaf_basis_test=basis_test,
98+
num_gfr=0,
99+
num_mcmc=num_mcmc,
100+
previous_model_json=bart_model_json,
101+
previous_model_warmstart_sample_num=num_warmstart - 2,
102+
general_params=general_model_params,
103+
mean_forest_params=mean_forest_model_params,
104+
)
105+
106+
# Run several BART MCMC samples from root
107+
bart_model_4 = BARTModel()
108+
bart_model_4.sample(
109+
X_train=X_train,
110+
y_train=y_train,
111+
leaf_basis_train=basis_train,
112+
X_test=X_test,
113+
leaf_basis_test=basis_test,
114+
num_gfr=0,
115+
num_mcmc=num_mcmc,
116+
general_params=general_model_params,
117+
mean_forest_params=mean_forest_model_params,
118+
)
119+
120+
# Inspect the model outputs
121+
bart_preds_2 = bart_model_2.predict(X_test, basis_test)
122+
y_hat_mcmc_2 = bart_preds_2['y_hat']
123+
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
124+
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
125+
bart_preds_3 = bart_model_3.predict(X_test, basis_test)
126+
y_hat_mcmc_3 = bart_preds_3['y_hat']
127+
y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True)
128+
bart_preds_4 = bart_model_4.predict(X_test, basis_test)
129+
y_hat_mcmc_4 = bart_preds_4['y_hat']
130+
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
131+
y_df = pd.DataFrame(
132+
np.concatenate(
133+
(y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)),
134+
axis=1,
135+
),
136+
columns=["First Chain", "Second Chain", "Third Chain", "Outcome"],
137+
)
138+
139+
# Compare first warm-start chain to root chain with equal number of MCMC draws
140+
sns.scatterplot(data=y_df, x="First Chain", y="Third Chain")
141+
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
142+
plt.show()
143+
144+
# Compare first warm-start chain to outcome
145+
sns.scatterplot(data=y_df, x="First Chain", y="Outcome")
146+
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
147+
plt.show()
148+
149+
# Compare root chain to outcome
150+
sns.scatterplot(data=y_df, x="Third Chain", y="Outcome")
151+
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
152+
plt.show()
153+
154+
# Compute RMSEs
155+
rmse_1 = np.sqrt(
156+
np.mean((np.squeeze(y_avg_mcmc_2) - y_test) * (np.squeeze(y_avg_mcmc_2) - y_test))
157+
)
158+
rmse_2 = np.sqrt(
159+
np.mean((np.squeeze(y_avg_mcmc_3) - y_test) * (np.squeeze(y_avg_mcmc_3) - y_test))
160+
)
161+
rmse_3 = np.sqrt(
162+
np.mean((np.squeeze(y_avg_mcmc_4) - y_test) * (np.squeeze(y_avg_mcmc_4) - y_test))
163+
)
164+
print(
165+
"Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(
166+
rmse_1, rmse_2, rmse_3
167+
)
168+
)

0 commit comments

Comments
 (0)