@@ -101,9 +101,10 @@ class PGBART(ArrayStepShared):
         Number of particles for the conditional SMC sampler. Defaults to 10
     max_stages : int
         Maximum number of iterations of the conditional SMC sampler. Defaults to 100.
-    batch : int
+    batch : int or tuple
         Number of trees fitted per step. Defaults to "auto", which is 10% of the `m` trees
-        during tuning and 20% after tuning.
+        during tuning and 20% after tuning. If a tuple is passed, the first element is the batch size
+        during tuning and the second the batch size after tuning.
     model: PyMC Model
         Optional model for sampling step. Defaults to None (taken from context).

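For orientation, a hedged usage sketch of the new tuple form. The data, the model, and the import path of `PGBART` are assumptions for illustration; only the `batch` semantics come from this diff:

```python
# Illustrative sketch only: assumes a PyMC build where pm.BART and PGBART
# are importable as shown; the data and model are made up.
import numpy as np
import pymc as pm
from pymc.step_methods import PGBART  # assumed import path

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
Y = X[:, 0] + rng.normal(scale=0.1, size=100)

with pm.Model():
    mu = pm.BART("mu", X, Y, m=50)
    pm.Normal("y", mu=mu, sigma=0.1, observed=Y)
    # 5 trees per step while tuning, 10 per step afterwards
    idata = pm.sample(step=PGBART(batch=(5, 10)))
```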
@@ -138,9 +139,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.alpha = self.bart.alpha
         self.k = self.bart.k
         self.response = self.bart.response
-        self.split_prior = self.bart.split_prior
-        if self.split_prior is None:
-            self.split_prior = np.ones(self.X.shape[1])
+        self.alpha_vec = self.bart.split_prior
+        if self.alpha_vec is None:
+            self.alpha_vec = np.ones(self.X.shape[1])

         self.init_mean = self.Y.mean()
         # if data is binary
@@ -149,7 +150,7 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
             self.mu_std = 6 / (self.k * self.m ** 0.5)
         # maybe we need to check for count data
         else:
-            self.mu_std = self.Y.std() / (self.k * self.m ** 0.5)
+            self.mu_std = (2 * self.Y.std()) / (self.k * self.m ** 0.5)

         self.num_observations = self.X.shape[0]
         self.num_variates = self.X.shape[1]
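The non-binary branch doubles the previous leaf-value scale. A quick check of the numbers (the values below are invented):

```python
# With std(Y) = 1, k = 2 and m = 50 trees:
m, k, y_std = 50, 2, 1.0
old = y_std / (k * m ** 0.5)        # previous scale
new = (2 * y_std) / (k * m ** 0.5)  # scale after this change
print(round(old, 3), round(new, 3))  # 0.071 0.141
```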
@@ -167,14 +168,18 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo

         self.normal = NormalSampler()
         self.prior_prob_leaf_node = compute_prior_probability(self.alpha)
-        self.ssv = SampleSplittingVariable(self.split_prior)
+        self.ssv = SampleSplittingVariable(self.alpha_vec)

         self.tune = True
-        self.idx = 0
-        self.batch = batch

-        if self.batch == "auto":
-            self.batch = max(1, int(self.m * 0.1))
+        if batch == "auto":
+            self.batch = (max(1, int(self.m * 0.1)), max(1, int(self.m * 0.2)))
+        else:
+            if isinstance(batch, (tuple, list)):
+                self.batch = batch
+            else:
+                self.batch = (batch, batch)
+
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
         self.len_indices = len(self.indices)
@@ -187,6 +192,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, batch="auto", mo
         self.all_particles = []
         for i in range(self.m):
             self.a_tree.tree_id = i
+            self.a_tree.leaf_node_value = (
+                self.init_mean / self.m + self.normal.random() * self.mu_std
+            )
             p = ParticleTree(
                 self.a_tree,
                 self.init_log_weight,
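Each tree's initial leaf value is now the data mean split across the `m` trees plus Gaussian jitter of scale `mu_std`, so the initial sum of trees is centered on the mean of `Y`. A sanity check of that reasoning (numbers invented):

```python
import numpy as np

init_mean, m, mu_std = 2.0, 50, 0.14
rng = np.random.default_rng(0)
leaf_values = init_mean / m + rng.normal(size=m) * mu_std
# Sum over trees: ~2.0 plus noise of scale mu_std * sqrt(m) ~= 0.99
print(leaf_values.sum())
```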
@@ -201,20 +209,16 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         sum_trees_output = q.data
         variable_inclusion = np.zeros(self.num_variates, dtype="int")

-        if self.idx == self.m:
-            self.idx = 0
-
-        for tree_id in range(self.idx, self.idx + self.batch):
-            if tree_id >= self.m:
-                break
+        tree_ids = np.random.randint(0, self.m, size=self.batch[~self.tune])
+        for tree_id in tree_ids:
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
             particles = self.init_particles(tree_id)
             # Compute the sum of trees without the tree we are attempting to replace
             self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()

             # The old tree is not growing so we update the weights only once.
-            self.update_weight(particles[0])
+            self.update_weight(particles[0], new=True)
             for t in range(self.max_stages):
                 # Sample each particle (try to grow each tree), except for the first one.
                 for p in particles[1:]:
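The `self.batch[~self.tune]` indexing relies on Python's bitwise NOT on booleans: `~True == -2` and `~False == -1`, so the 2-tuple is indexed from the end. A minimal demonstration:

```python
batch = (5, 10)  # (tuning, post-tuning) batch sizes
tune = True
assert batch[~tune] == 5    # ~True == -2, first element is used while tuning
tune = False
assert batch[~tune] == 10   # ~False == -1, second element is used after tuning
```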
@@ -235,15 +239,15 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                     if tree_grew:
                         self.update_weight(p)
                 # Normalize weights
-                W_t, normalized_weights = self.normalize(particles)
+                W_t, normalized_weights = self.normalize(particles[1:])

                 # Resample all but first particle
-                re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
+                re_n_w = normalized_weights
                 new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
                 particles[1:] = particles[new_indices]

                 # Set the new weights
-                for p in particles:
+                for p in particles[1:]:
                     p.log_weight = W_t

                 # Check if particles can keep growing, otherwise stop iterating
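Since `normalize` now sees only the growing particles, its output can feed `np.random.choice` directly and the second renormalization pass is gone. A toy version of the resampling step (weights invented):

```python
import numpy as np

log_w = np.array([-1.2, -0.7, -3.0])       # log-weights of particles[1:]
w = np.exp(log_w - log_w.max())
normalized_weights = w / w.sum()            # already sums to 1
indices = np.arange(1, 4)                   # self.indices for 4 particles
new_indices = np.random.choice(indices, size=indices.size, p=normalized_weights)
print(new_indices)                          # e.g. [2 1 1]
```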
@@ -254,23 +258,25 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                 if all(non_available_nodes_for_expansion):
                     break

+            for p in particles[1:]:
+                p.log_weight = p.old_likelihood_logp
+
+            _, normalized_weights = self.normalize(particles)
             # Get the new tree and update
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
-            self.all_trees[self.idx] = new_tree
+            self.all_trees[tree_id] = new_tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
             self.all_particles[tree_id] = new_particle
             sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()

             if self.tune:
+                self.ssv = SampleSplittingVariable(self.alpha_vec)
                 for index in new_particle.used_variates:
-                    self.split_prior[index] += 1
-                self.ssv = SampleSplittingVariable(self.split_prior)
+                    self.alpha_vec[index] += 1
             else:
-                self.batch = max(1, int(self.m * 0.2))
                 for index in new_particle.used_variates:
                     variable_inclusion[index] += 1
-            self.idx += 1

         stats = {"variable_inclusion": variable_inclusion, "bart_trees": copy(self.all_trees)}
         sum_trees_output = RaveledVars(sum_trees_output, point_map_info)
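During tuning, each accepted tree bumps the pseudo-count of every variable it split on, so `alpha_vec / alpha_vec.sum()` is the posterior-mean split distribution of a Dirichlet-Multinomial model (as the `SampleSplittingVariable` docstring below says). A sketch of that bookkeeping (counts invented):

```python
import numpy as np

alpha_vec = np.ones(3)        # uniform prior over 3 predictors
for index in [1, 1, 2]:       # variables used by an accepted tree
    alpha_vec[index] += 1
probs = alpha_vec / alpha_vec.sum()
print(probs)  # [0.167 0.5 0.333]: variable 1 is now favoured for future splits
```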
@@ -323,7 +329,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:

         return np.array(particles)

-    def update_weight(self, particle: List[ParticleTree]) -> None:
+    def update_weight(self, particle: List[ParticleTree], new=False) -> None:
         """
         Update the weight of a particle

@@ -333,20 +339,22 @@ def update_weight(self, particle: List[ParticleTree]) -> None:
         new_likelihood = self.likelihood_logp(
             self.sum_trees_output_noi + particle.tree.predict_output()
         )
-        particle.log_weight += new_likelihood - particle.old_likelihood_logp
-        particle.old_likelihood_logp = new_likelihood
+        if new:
+            particle.log_weight = new_likelihood
+        else:
+            particle.log_weight += new_likelihood - particle.old_likelihood_logp
+            particle.old_likelihood_logp = new_likelihood


 class SampleSplittingVariable:
-    def __init__(self, alpha_prior):
+    def __init__(self, alpha_vec):
         """
-        Sample splitting variables proportional to `alpha_prior`.
+        Sample splitting variables proportional to `alpha_vec`.

-        This is equivalent as sampling weights from a Dirichlet distribution with `alpha_prior`
-        parameter and then using those weights to sample from the available splitting variables.
+        This is equivalent to computing the posterior mean of a Dirichlet-Multinomial model.
         This enforces sparsity.
         """
-        self.enu = list(enumerate(np.cumsum(alpha_prior / alpha_prior.sum())))
+        self.enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))

     def rvs(self):
         r = np.random.random()
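The body of `rvs` is cut off here, but from how `self.enu` is built the sampling scheme is inverse-CDF over the normalized `alpha_vec`: draw a uniform number and return the first index whose cumulative weight covers it. A standalone sketch of that idea:

```python
import numpy as np

alpha_vec = np.array([1.0, 3.0, 1.0])
enu = list(enumerate(np.cumsum(alpha_vec / alpha_vec.sum())))  # [(0, 0.2), (1, 0.8), (2, 1.0)]
r = np.random.random()
index = next(i for i, cum_prob in enu if r <= cum_prob)
```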