File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change 9999 if 'cudf' in str (type (y_train )):
100100 params .n_classes = y_train [y_train .columns [0 ]].nunique ()
101101 else :
102- params .n_classes = len (np .unique (y_train ))
102+ unique_y_train = np .unique (y_train )
103+ params .n_classes = len (unique_y_train )
104+ if max (unique_y_train ) != len (unique_y_train ) - 1 :
105+ params .n_classes = int (max (unique_y_train )) + 1
106+
103107 if params .n_classes > 2 :
104108 lgbm_params ['num_class' ] = params .n_classes
105109
Original file line number Diff line number Diff line change @@ -30,7 +30,7 @@ def convert_xgb_predictions(y_pred, objective):
3030 if objective == 'multi:softprob' :
3131 y_pred = convert_probs_to_classes (y_pred )
3232 elif objective == 'binary:logistic' :
33- y_pred = y_pred .astype (np .int32 )
33+ y_pred = ( y_pred >= 0.5 ) .astype (np .int32 )
3434 return y_pred
3535
3636
You can’t perform that action at this time.
0 commit comments