@@ -53,6 +53,32 @@ impl Default for TreeMethod {
5353 fn default ( ) -> Self { TreeMethod :: Auto }
5454}
5555
56+ impl From < String > for TreeMethod
57+ {
58+ fn from ( s : String ) -> Self
59+ {
60+ use std:: borrow:: Borrow ;
61+ Self :: from ( s. borrow ( ) )
62+ }
63+ }
64+
65+ impl < ' a > From < & ' a str > for TreeMethod
66+ {
67+ fn from ( s : & ' a str ) -> Self
68+ {
69+ match s
70+ {
71+ "auto" => TreeMethod :: Auto ,
72+ "exact" => TreeMethod :: Exact ,
73+ "approx" => TreeMethod :: Approx ,
74+ "hist" => TreeMethod :: Hist ,
75+ "gpu_exact" => TreeMethod :: GpuExact ,
76+ "gpu_hist" => TreeMethod :: GpuHist ,
77+ _ => panic ! ( "no known tree_method for {}" , s)
78+ }
79+ }
80+ }
81+
5682/// Provides a modular way to construct and to modify the trees. This is an advanced parameter that is usually set
5783/// automatically, depending on some other parameters. However, it could be also set explicitly by a user.
5884#[ derive( Clone ) ]
@@ -191,7 +217,7 @@ pub struct TreeBoosterParameters {
191217 ///
192218 /// * range: [0,∞]
193219 /// * default: 0
194- gamma : u32 ,
220+ gamma : f32 ,
195221
196222 /// Maximum depth of a tree, increase this value will make the model more complex / likely to be overfitting.
197223 /// 0 indicates no limit, limit is required for depth-wise grow policy.
@@ -208,7 +234,7 @@ pub struct TreeBoosterParameters {
208234 ///
209235 /// * range: [0,∞]
210236 /// * default: 1
211- min_child_weight : u32 ,
237+ min_child_weight : f32 ,
212238
213239 /// Maximum delta step we allow each tree’s weight estimation to be.
214240 /// If the value is set to 0, it means there is no constraint. If it is set to a positive value,
@@ -218,7 +244,7 @@ pub struct TreeBoosterParameters {
218244 ///
219245 /// * range: [0,∞]
220246 /// * default: 0
221- max_delta_step : u32 ,
247+ max_delta_step : f32 ,
222248
223249 /// Subsample ratio of the training instance. Setting it to 0.5 means that XGBoost randomly collected half
224250 /// of the data instances to grow trees and this will prevent overfitting.
@@ -239,15 +265,21 @@ pub struct TreeBoosterParameters {
239265 /// * default: 1.0
240266 colsample_bylevel : f32 ,
241267
268+ /// Subsample ratio of columns for each node.
269+ ///
270+ /// * range: (0.0, 1.0]
271+ /// * default: 1.0
272+ colsample_bynode : f32 ,
273+
242274 /// L2 regularization term on weights, increase this value will make model more conservative.
243275 ///
244276 /// * default: 1
245- lambda : u32 ,
277+ lambda : f32 ,
246278
247279 /// L1 regularization term on weights, increase this value will make model more conservative.
248280 ///
249281 /// * default: 0
250- alpha : u32 ,
282+ alpha : f32 ,
251283
252284 /// The tree construction algorithm used in XGBoost.
253285 #[ builder( default = "TreeMethod::default()" ) ]
@@ -270,7 +302,7 @@ pub struct TreeBoosterParameters {
270302
271303 /// Sequence of tree updaters to run, providing a modular way to construct and to modify the trees.
272304 ///
273- /// * default: [TreeUpdater::GrowColMaker, TreeUpdater::Prune ]
305+ /// * default: vec![ ]
274306 updater : Vec < TreeUpdater > ,
275307
276308 /// This is a parameter of the ‘refresh’ updater plugin. When this flag is true, tree leafs as well as tree nodes'
@@ -300,6 +332,11 @@ pub struct TreeBoosterParameters {
300332 /// * default: 256
301333 max_bin : u32 ,
302334
335+ /// Number of trees to train in parallel for boosted random forest.
336+ ///
337+ /// * default: 1
338+ num_parallel_tree : u32 ,
339+
303340 /// The type of predictor algorithm to use. Provides the same results but allows the use of GPU or CPU.
304341 ///
305342 /// * default: [`Predictor::Cpu`](enum.Predictor.html#variant.Cpu)
@@ -310,24 +347,26 @@ impl Default for TreeBoosterParameters {
310347 fn default ( ) -> Self {
311348 TreeBoosterParameters {
312349 eta : 0.3 ,
313- gamma : 0 ,
350+ gamma : 0.0 ,
314351 max_depth : 6 ,
315- min_child_weight : 1 ,
316- max_delta_step : 0 ,
352+ min_child_weight : 1.0 ,
353+ max_delta_step : 0.0 ,
317354 subsample : 1.0 ,
318355 colsample_bytree : 1.0 ,
319356 colsample_bylevel : 1.0 ,
320- lambda : 1 ,
321- alpha : 0 ,
357+ colsample_bynode : 1.0 ,
358+ lambda : 1.0 ,
359+ alpha : 0.0 ,
322360 tree_method : TreeMethod :: default ( ) ,
323361 sketch_eps : 0.03 ,
324362 scale_pos_weight : 1.0 ,
325- updater : vec ! [ TreeUpdater :: GrowColMaker , TreeUpdater :: Prune ] ,
363+ updater : Vec :: new ( ) ,
326364 refresh_leaf : true ,
327365 process_type : ProcessType :: default ( ) ,
328366 grow_policy : GrowPolicy :: default ( ) ,
329367 max_leaves : 0 ,
330368 max_bin : 256 ,
369+ num_parallel_tree : 1 ,
331370 predictor : Predictor :: default ( ) ,
332371 }
333372 }
@@ -347,19 +386,29 @@ impl TreeBoosterParameters {
347386 v. push ( ( "subsample" . to_owned ( ) , self . subsample . to_string ( ) ) ) ;
348387 v. push ( ( "colsample_bytree" . to_owned ( ) , self . colsample_bytree . to_string ( ) ) ) ;
349388 v. push ( ( "colsample_bylevel" . to_owned ( ) , self . colsample_bylevel . to_string ( ) ) ) ;
389+ v. push ( ( "colsample_bynode" . to_owned ( ) , self . colsample_bynode . to_string ( ) ) ) ;
350390 v. push ( ( "lambda" . to_owned ( ) , self . lambda . to_string ( ) ) ) ;
351391 v. push ( ( "alpha" . to_owned ( ) , self . alpha . to_string ( ) ) ) ;
352392 v. push ( ( "tree_method" . to_owned ( ) , self . tree_method . to_string ( ) ) ) ;
353393 v. push ( ( "sketch_eps" . to_owned ( ) , self . sketch_eps . to_string ( ) ) ) ;
354394 v. push ( ( "scale_pos_weight" . to_owned ( ) , self . scale_pos_weight . to_string ( ) ) ) ;
355- v. push ( ( "updater" . to_owned ( ) , self . updater . iter ( ) . map ( |u| u. to_string ( ) ) . collect :: < Vec < String > > ( ) . join ( "," ) ) ) ;
356395 v. push ( ( "refresh_leaf" . to_owned ( ) , ( self . refresh_leaf as u8 ) . to_string ( ) ) ) ;
357396 v. push ( ( "process_type" . to_owned ( ) , self . process_type . to_string ( ) ) ) ;
358397 v. push ( ( "grow_policy" . to_owned ( ) , self . grow_policy . to_string ( ) ) ) ;
359398 v. push ( ( "max_leaves" . to_owned ( ) , self . max_leaves . to_string ( ) ) ) ;
360399 v. push ( ( "max_bin" . to_owned ( ) , self . max_bin . to_string ( ) ) ) ;
400+ v. push ( ( "num_parallel_tree" . to_owned ( ) , self . num_parallel_tree . to_string ( ) ) ) ;
361401 v. push ( ( "predictor" . to_owned ( ) , self . predictor . to_string ( ) ) ) ;
362402
403+ // Don't pass anything to XGBoost if the user didn't specify anything.
404+ // This allows XGBoost to figure it out on it's own, and suppresses the
405+ // warning message during training.
406+ // See: https://github.com/davechallis/rust-xgboost/issues/7
407+ if self . updater . len ( ) != 0
408+ {
409+ v. push ( ( "updater" . to_owned ( ) , self . updater . iter ( ) . map ( |u| u. to_string ( ) ) . collect :: < Vec < String > > ( ) . join ( "," ) ) ) ;
410+ }
411+
363412 v
364413 }
365414}
@@ -370,6 +419,7 @@ impl TreeBoosterParametersBuilder {
370419 Interval :: new_open_closed ( 0.0 , 1.0 ) . validate ( & self . subsample , "subsample" ) ?;
371420 Interval :: new_open_closed ( 0.0 , 1.0 ) . validate ( & self . colsample_bytree , "colsample_bytree" ) ?;
372421 Interval :: new_open_closed ( 0.0 , 1.0 ) . validate ( & self . colsample_bylevel , "colsample_bylevel" ) ?;
422+ Interval :: new_open_closed ( 0.0 , 1.0 ) . validate ( & self . colsample_bynode , "colsample_bynode" ) ?;
373423 Interval :: new_open_open ( 0.0 , 1.0 ) . validate ( & self . sketch_eps , "sketch_eps" ) ?;
374424 Ok ( ( ) )
375425 }
0 commit comments