Skip to content

Commit 45ec8e6

Browse files
ntreelimit has been deprecated in favour of iteration_range (#656)
* Adding the correct argument to the `xgb_pred` call. As `ntreelimit` as been deprecated, adding `iterationrange` to the `xgb_pred` call with the correct modifications to fit the current use case. * Reformat code * Use xgboost >= 1.5.0.1 because of `ntreelimit` / `iterationrange` change * Update pkgdown and testing actions * Update NEWS * Parameter is actually `iteration_range` * Update test for new `iteration_range` arg * Apply suggestions from code review Co-authored-by: Tiago Maié <tiagomaie@hotmail.com> Co-authored-by: Julia Silge <julia.silge@gmail.com>
1 parent 90ffe6b commit 45ec8e6

File tree

6 files changed

+37
-14
lines changed

6 files changed

+37
-14
lines changed

.github/workflows/pkgdown.yaml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,20 @@ jobs:
2929
extra-packages: r-lib/pkgdown
3030
needs: website
3131

32-
- name: Install Miniconda + TensorFlow
32+
- name: Install dev reticulate
33+
run: pak::pkg_install('rstudio/reticulate')
34+
shell: Rscript {0}
35+
36+
- name: Install Miniconda
37+
# conda can fail at downgrading python, so we specify python version in advance
38+
env:
39+
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
40+
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
41+
shell: Rscript {0}
42+
43+
- name: Install TensorFlow
3344
run: |
34-
pak::pkg_install('rstudio/reticulate')
35-
reticulate::install_miniconda()
36-
reticulate::conda_create('r-reticulate', packages = c('python==3.6.9'))
37-
tensorflow::install_tensorflow(version='2.7.0')
45+
tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL)
3846
shell: Rscript {0}
3947

4048
- name: Install package

.github/workflows/test-coverage.yaml

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,20 @@ jobs:
2626
extra-packages: any::covr
2727
needs: coverage
2828

29-
- name: Install Miniconda + TensorFlow
29+
- name: Install dev reticulate
30+
run: pak::pkg_install('rstudio/reticulate')
31+
shell: Rscript {0}
32+
33+
- name: Install Miniconda
34+
# conda can fail at downgrading python, so we specify python version in advance
35+
env:
36+
RETICULATE_MINICONDA_PYTHON_VERSION: "3.7"
37+
run: reticulate::install_miniconda() # creates r-reticulate conda env by default
38+
shell: Rscript {0}
39+
40+
- name: Install TensorFlow
3041
run: |
31-
pak::pkg_install('rstudio/reticulate')
32-
reticulate::install_miniconda()
33-
reticulate::conda_create('r-reticulate', packages = c('python==3.6.9'))
34-
tensorflow::install_tensorflow(version='2.7.0')
42+
tensorflow::install_tensorflow(version='2.7', conda_python_version = NULL)
3543
shell: Rscript {0}
3644

3745
- name: Test coverage

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Suggests:
5959
sparklyr (>= 1.0.0),
6060
survival,
6161
testthat,
62-
xgboost
62+
xgboost (>= 1.5.0.1)
6363
VignetteBuilder:
6464
knitr
6565
ByteCompile: true

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
* parsnip is now more robust working with keras and tensorflow for a larger range of versions (#596).
4444

45+
* xgboost engines now use the new `iterationrange` parameter instead of the deprecated `ntreelimit` (#656).
46+
4547
# parsnip 0.1.7
4648

4749
## Model Specification Changes

R/boost_tree.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,12 @@ multi_predict._xgb.Booster <-
482482
}
483483

484484
xgb_by_tree <- function(tree, object, new_data, type, ...) {
485-
pred <- xgb_pred(object$fit, newdata = new_data, ntreelimit = tree)
485+
pred <- xgb_pred(
486+
object$fit,
487+
newdata = new_data,
488+
iterationrange = c(1, tree + 1),
489+
ntreelimit = NULL
490+
)
486491

487492
# switch based on prediction type
488493
if (object$spec$mode == "regression") {

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ test_that('submodel prediction', {
191191

192192
x <- xgboost::xgb.DMatrix(as.matrix(mtcars[1:4, -1]))
193193

194-
pruned_pred <- predict(reg_fit$fit, x, ntreelimit = 5)
194+
pruned_pred <- predict(reg_fit$fit, x, iterationrange = c(1, 6))
195195

196196
mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], trees = 5)
197197
mp_res <- do.call("rbind", mp_res$.pred)
@@ -206,7 +206,7 @@ test_that('submodel prediction', {
206206

207207
x <- xgboost::xgb.DMatrix(as.matrix(wa_churn[1:4, vars]))
208208

209-
pred_class <- predict(class_fit$fit, x, ntreelimit = 5)
209+
pred_class <- predict(class_fit$fit, x, iterationrange = c(1, 6))
210210

211211
mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], trees = 5, type = "prob")
212212
mp_res <- do.call("rbind", mp_res$.pred)

0 commit comments

Comments
 (0)