1515import numpy as np
1616
1717from pandas import DataFrame , Series
18+ from scipy .special import expit
1819
1920from pymc3 .distributions .distribution import NoDistribution
2021from pymc3 .distributions .tree import LeafNode , SplitNode , Tree
@@ -30,7 +31,6 @@ def __init__(
3031 m = 200 ,
3132 alpha = 0.25 ,
3233 split_prior = None ,
33- scale = None ,
3434 inv_link = None ,
3535 jitter = False ,
3636 * args ,
@@ -63,22 +63,32 @@ def __init__(
6363 )
6464 self .m = m
6565 self .alpha = alpha
66- self .y_std = Y .std ()
67-
68- if scale is None :
69- self .leaf_scale = NormalSampler (sigma = None )
70- elif isinstance (scale , (int , float )):
71- self .leaf_scale = NormalSampler (sigma = Y .std () / self .m ** scale )
7266
7367 if inv_link is None :
74- self .inv_link = lambda x : x
68+ self .inv_link = self .link = lambda x : x
69+ elif isinstance (inv_link , str ):
70+ # The link function is just a rough approximation in order to allow the PGBART sampler
71+ # to propose reasonable values for the leaf nodes.
72+ if inv_link == "logistic" :
73+ self .inv_link = expit
74+ self .link = lambda x : (x - 0.5 ) * 10
75+ elif inv_link == "exp" :
76+ self .inv_link = np .exp
77+ self .link = np .log
78+ self .Y [self .Y == 0 ] += 0.0001
79+ else :
80+ raise ValueError ("Accepted strings are 'logistic' or 'exp'" )
7581 else :
76- self .inv_link = inv_link
82+ self .inv_link , self .link = inv_link
83+
84+ self .init_mean = self .link (self .Y .mean ())
85+ self .Y_un = self .link (self .Y )
7786
7887 self .num_observations = X .shape [0 ]
7988 self .num_variates = X .shape [1 ]
8089 self .available_predictors = list (range (self .num_variates ))
8190 self .ssv = SampleSplittingVariable (split_prior , self .num_variates )
91+ self .initial_value_leaf_nodes = self .init_mean / self .m
8292 self .trees = self .init_list_of_trees ()
8393 self .all_trees = []
8494 self .mean = fast_mean ()
@@ -96,7 +106,7 @@ def preprocess_XY(self, X, Y):
96106 return X , Y , missing_data
97107
98108 def init_list_of_trees (self ):
99- initial_value_leaf_nodes = self .Y . mean () / self . m
109+ initial_value_leaf_nodes = self .initial_value_leaf_nodes
100110 initial_idx_data_points_leaf_nodes = np .array (range (self .num_observations ), dtype = "int32" )
101111 list_of_trees = []
102112 for i in range (self .m ):
@@ -110,7 +120,7 @@ def init_list_of_trees(self):
110120 # bartMachine: A Powerful Tool for Machine Learning in R. ArXiv e-prints, 2013
111121 # The sum_trees_output will contain the sum of the predicted output for all trees.
112122 # When R_j is needed we subtract the current predicted output for tree T_j.
113- self .sum_trees_output = np .full_like (self .Y , self .Y . mean () )
123+ self .sum_trees_output = np .full_like (self .Y , self .init_mean )
114124
115125 return list_of_trees
116126
@@ -181,14 +191,13 @@ def get_new_idx_data_points(self, current_split_node, idx_data_points):
181191
182192 def get_residuals (self ):
183193 """Compute the residuals."""
184- R_j = self .Y - self .inv_link (self .sum_trees_output )
185-
194+ R_j = self .Y_un - self .sum_trees_output
186195 return R_j
187196
188197 def draw_leaf_value (self , idx_data_points ):
189198 """Draw the residual mean."""
190199 R_j = self .get_residuals ()[idx_data_points ]
191- draw = self .mean (R_j ) + self . leaf_scale . random ()
200+ draw = self .mean (R_j )
192201 return draw
193202
194203 def predict (self , X_new ):
@@ -278,24 +287,6 @@ def rvs(self):
278287 return i
279288
280289
281- class NormalSampler :
282- def __init__ (self , sigma ):
283- self .size = 5000
284- self .cache = []
285- self .sigma = sigma
286-
287- def random (self ):
288- if self .sigma is None :
289- return 0
290- else :
291- if not self .cache :
292- self .update ()
293- return self .cache .pop ()
294-
295- def update (self ):
296- self .cache = np .random .normal (loc = 0.0 , scale = self .sigma , size = self .size ).tolist ()
297-
298-
299290class BART (BaseBART ):
300291 """
301292 BART distribution.
@@ -317,23 +308,17 @@ class BART(BaseBART):
317308 Each element of split_prior should be in the [0, 1] interval and the elements should sum
318309 to 1. Otherwise they will be normalized.
319310 Defaults to None, all variable have the same a prior probability
320- scale : float
321- Controls the variance of the proposed leaf value. The leaf values are computed as a
322- Gaussian with mean equal to the conditional residual mean and variance proportional to
323- the variance of the response variable, and inversely proportional to the number of trees
324- and the scale parameter. Defaults to None, i.e the variance is 0.
325- inv_link : numpy function
326- Inverse link function defaults to None, i.e. the identity function.
311+ inv_link : str or tuple of functions
312+ Inverse link function defaults to None, i.e. the identity function. Accepted strings are
313+ ``logistic`` or ``exp``.
327314 jitter : bool
328315 Whether to jitter the X values or not. Defaults to False. When values of X are repeated,
329316 jittering X has the effect of increasing the number of effective spliting variables,
330317 otherwise it does not have any effect.
331318 """
332319
333- def __init__ (
334- self , X , Y , m = 200 , alpha = 0.25 , split_prior = None , scale = None , inv_link = None , jitter = False
335- ):
336- super ().__init__ (X , Y , m , alpha , split_prior , scale , inv_link )
320+ def __init__ (self , X , Y , m = 200 , alpha = 0.25 , split_prior = None , inv_link = None , jitter = False ):
321+ super ().__init__ (X , Y , m , alpha , split_prior , inv_link )
337322
338323 def _str_repr (self , name = None , dist = None , formatting = "plain" ):
339324 if dist is None :
0 commit comments