1515# ===============================================================================
1616
1717import argparse
18-
1918import bench
2019from cuml .ensemble import RandomForestRegressor
2120
2221parser = argparse .ArgumentParser (description = 'cuml random forest '
2322 'regression benchmark' )
2423
25- parser .add_argument ('--criterion' , type = str , default = 'mse' ,
26- choices = ('mse' , 'mae' ),
27- help = 'The function to measure the quality of a split' )
2824parser .add_argument ('--split-algorithm' , type = str , default = 'hist' ,
2925 choices = ('hist' , 'global_quantile' ),
3026 help = 'The algorithm to determine how '
3127 'nodes are split in the tree' )
3228parser .add_argument ('--num-trees' , type = int , default = 100 ,
3329 help = 'Number of trees in the forest' )
34- parser .add_argument ('--max-features' , type = bench .float_or_int , default = None ,
30+ parser .add_argument ('--max-features' , type = bench .float_or_int , default = 1.0 ,
3531 help = 'Upper bound on features used at each split' )
36- parser .add_argument ('--max-depth' , type = int , default = None ,
32+ parser .add_argument ('--max-depth' , type = int , default = 16 ,
3733 help = 'Upper bound on depth of constructed trees' )
3834parser .add_argument ('--min-samples-split' , type = bench .float_or_int , default = 2 ,
3935 help = 'Minimum samples number for node splitting' )
4036parser .add_argument ('--max-leaf-nodes' , type = int , default = - 1 ,
4137 help = 'Maximum leaf nodes per tree' )
42- parser .add_argument ('--min-impurity-decrease' , type = float , default = 0. ,
38+ parser .add_argument ('--min-impurity-decrease' , type = float , default = 0.0 ,
4339 help = 'Needed impurity decrease for node splitting' )
4440parser .add_argument ('--no-bootstrap' , dest = 'bootstrap' , default = True ,
4541 action = 'store_false' , help = "Don't control bootstraping" )
4642
4743params = bench .parse_args (parser )
4844
4945# Load and convert data
50- X_train , X_test , y_train , y_test = bench .load_data (params )
51-
52- if params .criterion == 'mse' :
53- params .criterion = 2
54- else :
55- params .criterion = 3
46+ X_train , X_test , y_train , y_test = bench .load_data (params , int_label = True )
5647
5748if params .split_algorithm == 'hist' :
5849 params .split_algorithm = 0
6152
6253# Create our random forest regressor
6354regr = RandomForestRegressor (
64- split_criterion = params .criterion ,
65- split_algo = params .split_algorithm ,
6655 n_estimators = params .num_trees ,
67- max_depth = params .max_depth ,
56+ split_algo = params .split_algorithm ,
6857 max_features = params .max_features ,
6958 min_samples_split = params .min_samples_split ,
59+ max_depth = params .max_depth ,
7060 max_leaves = params .max_leaf_nodes ,
7161 min_impurity_decrease = params .min_impurity_decrease ,
7262 bootstrap = params .bootstrap ,
63+
7364)
7465
7566
@@ -82,7 +73,6 @@ def predict(regr, X):
8273
8374
8475fit_time , _ = bench .measure_function_time (fit , regr , X_train , y_train , params = params )
85-
8676y_pred = predict (regr , X_train )
8777train_rmse = bench .rmse_score (y_pred , y_train )
8878
0 commit comments