@@ -7,7 +7,8 @@ use xgboost_sys;
77
88use super :: { XGBResult , XGBError } ;
99
10- static KEY_ROOT_INDEX : & ' static str = "root_index" ;
10+ static KEY_GROUP_PTR : & ' static str = "group_ptr" ;
11+ static KEY_GROUP : & ' static str = "group" ;
1112static KEY_LABEL : & ' static str = "label" ;
1213static KEY_WEIGHT : & ' static str = "weight" ;
1314static KEY_BASE_MARGIN : & ' static str = "base_margin" ;
@@ -230,20 +231,6 @@ impl DMatrix {
230231 Ok ( DMatrix :: new ( out_handle) ?)
231232 }
232233
233- /// Gets the specified root index of each instance, can be used for multi task setting.
234- ///
235- /// See the XGBoost documentation for more information.
236- pub fn get_root_index ( & self ) -> XGBResult < & [ u32 ] > {
237- self . get_uint_info ( KEY_ROOT_INDEX )
238- }
239-
240- /// Sets the specified root index of each instance, can be used for multi task setting.
241- ///
242- /// See the XGBoost documentation for more information.
243- pub fn set_root_index ( & mut self , array : & [ u32 ] ) -> XGBResult < ( ) > {
244- self . set_uint_info ( KEY_ROOT_INDEX , array)
245- }
246-
247234 /// Get ground truth labels for each row of this matrix.
248235 pub fn get_labels ( & self ) -> XGBResult < & [ f32 ] > {
249236 self . get_float_info ( KEY_LABEL )
@@ -282,9 +269,20 @@ impl DMatrix {
282269 ///
283270 /// See the XGBoost documentation for more information.
284271 pub fn set_group ( & mut self , group : & [ u32 ] ) -> XGBResult < ( ) > {
285- xgb_call ! ( xgboost_sys:: XGDMatrixSetGroup ( self . handle, group. as_ptr( ) , group. len( ) as u64 ) )
272+ // same as xgb_call!(xgboost_sys::XGDMatrixSetGroup(self.handle, group.as_ptr(), group.len() as u64))
273+ self . set_uint_info ( KEY_GROUP , group)
274+ }
275+
276+ /// Get the index for the beginning and end of a group.
277+ ///
278+ /// Needed when the learning task is ranking.
279+ ///
280+ /// See the XGBoost documentation for more information.
281+ pub fn get_group ( & self ) -> XGBResult < & [ u32 ] > {
282+ self . get_uint_info ( KEY_GROUP_PTR )
286283 }
287284
285+
288286 fn get_float_info ( & self , field : & str ) -> XGBResult < & [ f32 ] > {
289287 let field = ffi:: CString :: new ( field) . unwrap ( ) ;
290288 let mut out_len = 0 ;
@@ -313,7 +311,6 @@ impl DMatrix {
313311 field. as_ptr( ) ,
314312 & mut out_len,
315313 & mut out_dptr) ) ?;
316-
317314 Ok ( unsafe { slice:: from_raw_parts ( out_dptr as * mut c_uint , out_len as usize ) } )
318315 }
319316
@@ -370,16 +367,6 @@ mod tests {
370367 // TODO: check contents as well, if possible
371368 }
372369
373- #[ test]
374- fn get_set_root_index ( ) {
375- let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
376- assert_eq ! ( dmat. get_root_index( ) . unwrap( ) , & [ ] ) ;
377-
378- let root_index = [ 3 , 22 , 1 ] ;
379- assert ! ( dmat. set_root_index( & root_index) . is_ok( ) ) ;
380- assert_eq ! ( dmat. get_root_index( ) . unwrap( ) , & [ 3 , 22 , 1 ] ) ;
381- }
382-
383370 #[ test]
384371 fn get_set_labels ( ) {
385372 let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
@@ -395,7 +382,7 @@ mod tests {
395382 let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
396383 assert_eq ! ( dmat. get_weights( ) . unwrap( ) , & [ ] ) ;
397384
398- let weight = [ 1.0 , 10.0 , - 123.456789 , 44.9555 ] ;
385+ let weight = [ 1.0 , 10.0 , 44.9555 ] ;
399386 assert ! ( dmat. set_weights( & weight) . is_ok( ) ) ;
400387 assert_eq ! ( dmat. get_weights( ) . unwrap( ) , weight) ;
401388 }
@@ -411,11 +398,13 @@ mod tests {
411398 }
412399
413400 #[ test]
414- fn set_group ( ) {
401+ fn get_set_group ( ) {
415402 let mut dmat = read_train_matrix ( ) . unwrap ( ) ;
403+ assert_eq ! ( dmat. get_group( ) . unwrap( ) , & [ ] ) ;
416404
417- let group = [ 1 , 2 , 3 ] ;
405+ let group = [ 1 ] ;
418406 assert ! ( dmat. set_group( & group) . is_ok( ) ) ;
407+ assert_eq ! ( dmat. get_group( ) . unwrap( ) , & [ 0 , 1 ] ) ;
419408 }
420409
421410 #[ test]
@@ -426,7 +415,7 @@ mod tests {
426415
427416 let dmat = DMatrix :: from_csr ( & indptr, & indices, & data, None ) . unwrap ( ) ;
428417 assert_eq ! ( dmat. num_rows( ) , 4 ) ;
429- assert_eq ! ( dmat. num_cols( ) , 3 ) ;
418+ assert_eq ! ( dmat. num_cols( ) , 0 ) ; // https://github.com/dmlc/xgboost/pull/7265
430419
431420 let dmat = DMatrix :: from_csr ( & indptr, & indices, & data, Some ( 10 ) ) . unwrap ( ) ;
432421 assert_eq ! ( dmat. num_rows( ) , 4 ) ;
@@ -477,7 +466,7 @@ mod tests {
477466 assert_eq ! ( dmat. slice( & [ 1 ] ) . unwrap( ) . shape( ) , ( 1 , 2 ) ) ;
478467 assert_eq ! ( dmat. slice( & [ 0 , 1 ] ) . unwrap( ) . shape( ) , ( 2 , 2 ) ) ;
479468 assert_eq ! ( dmat. slice( & [ 3 , 2 , 1 ] ) . unwrap( ) . shape( ) , ( 3 , 2 ) ) ;
480- assert ! ( dmat. slice( & [ 10 , 11 , 12 ] ) . is_err ( ) ) ;
469+ assert_eq ! ( dmat. slice( & [ 10 , 11 , 12 ] ) . unwrap ( ) . shape ( ) , ( 0 , 0 ) ) ;
481470 }
482471
483472 #[ test]
0 commit comments