Skip to content

Commit 20d5ba4

Browse files
committed
prototype of changes for #431
1 parent 01168ca commit 20d5ba4

File tree

3 files changed

+60
-20
lines changed

3 files changed

+60
-20
lines changed

R/linear_reg.R

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,22 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
107107
x <- translate.default(x, engine, ...)
108108

109109
if (engine == "glmnet") {
110-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
111-
x$method$fit$args$lambda <- NULL
110+
if (any(names(x$eng_args) == "path_values")) {
111+
# Since we decouple the parsnip `penalty` argument from being the same
112+
# as the glmnet `lambda` value, this allows users to set the path
113+
# differently from the default that glmnet uses. See
114+
# https://github.com/tidymodels/parsnip/issues/431
115+
x$method$fit$args$lambda <- x$eng_args$path_values
116+
x$eng_args$path_values <- NULL
117+
x$method$fit$args$path_values <- NULL
118+
} else {
119+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
120+
x$method$fit$args$lambda <- NULL
121+
}
112122
# Since the `fit` information is gone for the penalty, we need to have an
113123
# evaluated value for the parameter.
114124
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
115125
}
116-
117126
x
118127
}
119128

man/linear_reg.Rd

Lines changed: 28 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/linear-reg.Rmd

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,26 @@ linear_reg() %>%
2323
translate()
2424
```
2525

26-
For `glmnet` models, the full regularization path is always fit regardless of the
27-
value given to `penalty`. Also, there is the option to pass multiple values (or
28-
no values) to the `penalty` argument. When using the `predict()` method in these
29-
cases, the return value depends on the value of `penalty`. When using
30-
`predict()`, only a single value of the penalty can be used. When predicting on
31-
multiple penalties, the `multi_predict()` function can be used. It returns a
32-
tibble with a list column called `.pred` that contains a tibble with all of the
33-
penalty results.
26+
`linear_reg()` requires a single value for the `penalty` argument (a number
27+
or `tune()`). Despite this, the full regularization path is always fit
28+
regardless of the value given to `penalty`. To pass in a custom sequence of
29+
values for `lambda`, use the argument `path_values` in `set_engine()`. This
30+
will assign the value of the glmnet `lambda` parameter without disturbing
31+
the value given in `linear_reg(penalty)`. For example:
32+
33+
```{r glmnet-path}
34+
linear_reg(penalty = .1) %>%
35+
set_engine("glmnet", path_values = c(0, 10^seq(-10, 1, length.out = 20))) %>%
36+
set_mode("regression") %>%
37+
translate()
38+
```
39+
40+
When using `predict()`, the single penalty value used for prediction is the one
41+
given to `linear_reg()`.
42+
43+
To predict on multiple penalties, the `multi_predict()` function can be used.
44+
It returns a tibble with a list column called `.pred` that contains a tibble
45+
with all of the penalty results.
3446

3547
## stan
3648

0 commit comments

Comments
 (0)