Skip to content

Commit fb8d78b

Browse files
committed
Merge branch 'master' into path-values
2 parents 9ddb7ac + bc125e9 commit fb8d78b

16 files changed

+64
-108
lines changed

R/engines.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ load_libs <- function(x, quiet, attach = FALSE) {
8282
#' @examples
8383
#' # First, set general arguments using the standardized names
8484
#' mod <-
85-
#' logistic_reg(mixture = 1/3) %>%
85+
#' logistic_reg(penalty = 0.01, mixture = 1/3) %>%
8686
#' # now say how you want to fit the model and another other options
8787
#' set_engine("glmnet", nlambda = 10)
8888
#' translate(mod, engine = "glmnet")

R/linear_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
122122
# Since the `fit` information is gone for the penalty, we need to have an
123123
# evaluated value for the parameter.
124124
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
125+
check_glmnet_penalty(x)
125126
}
126127
x
127128
}

R/logistic_reg.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
124124
# Since the `fit` information is gone for the penalty, we need to have an
125125
# evaluated value for the parameter.
126126
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
127+
check_glmnet_penalty(x)
127128
}
128129

129130
if (engine == "LiblineaR") {

R/misc.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,4 +323,13 @@ stan_conf_int <- function(object, newdata) {
323323
rlang::eval_tidy(fn)
324324
}
325325

326-
326+
check_glmnet_penalty <- function(x) {
327+
if (length(x$args$penalty) != 1) {
328+
rlang::abort(c(
329+
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
330+
glue::glue("There are {length(x$args$penalty)} values for `penalty`."),
331+
"To try multiple values for total regularization, use the tune package.",
332+
"To predict multiple penalties, use `multi_predict()`"
333+
))
334+
}
335+
}

R/translate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
#' translate(lm_spec, engine = "spark")
3939
#'
4040
#' # with a placeholder for an unknown argument value:
41-
#' translate(linear_reg(mixture = varying()), engine = "glmnet")
41+
#' translate(linear_reg(penalty = varying(), mixture = varying()), engine = "glmnet")
4242
#'
4343
#' @export
4444

man/linear_reg.Rd

Lines changed: 4 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/logistic_reg.Rd

Lines changed: 4 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/multinom_reg.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/linear-reg.Rmd

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,14 @@ Engines may have pre-set default arguments when executing the model fit call. Fo
1010
```{r lm-reg}
1111
linear_reg() %>%
1212
set_engine("lm") %>%
13-
set_mode("regression") %>%
1413
translate()
1514
```
1615

1716
## glmnet
1817

1918
```{r glmnet-csl}
20-
linear_reg() %>%
19+
linear_reg(penalty = 0.1) %>%
2120
set_engine("glmnet") %>%
22-
set_mode("regression") %>%
2321
translate()
2422
```
2523

@@ -69,7 +67,6 @@ returned.
6967
```{r spark-reg}
7068
linear_reg() %>%
7169
set_engine("spark") %>%
72-
set_mode("regression") %>%
7370
translate()
7471
```
7572

@@ -78,7 +75,6 @@ linear_reg() %>%
7875
```{r keras-reg}
7976
linear_reg() %>%
8077
set_engine("keras") %>%
81-
set_mode("regression") %>%
8278
translate()
8379
```
8480

man/rmd/logistic-reg.Rmd

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,13 @@ For this type of model, the template of the fit calls are below.
1111
```{r glm-reg}
1212
logistic_reg() %>%
1313
set_engine("glm") %>%
14-
set_mode("classification") %>%
1514
translate()
1615
```
1716

1817
## glmnet
1918

2019
```{r glmnet-csl}
21-
logistic_reg() %>%
20+
logistic_reg(penalty = 0.1) %>%
2221
set_engine("glmnet") %>%
2322
translate()
2423
```
@@ -52,7 +51,6 @@ with all of the penalty results.
5251
```{r liblinear-reg}
5352
logistic_reg() %>%
5453
set_engine("LiblineaR") %>%
55-
set_mode("classification") %>%
5654
translate()
5755
```
5856

@@ -68,7 +66,6 @@ regularized regression models do not, which will result in different parameter e
6866
```{r stan-reg}
6967
logistic_reg() %>%
7068
set_engine("stan") %>%
71-
set_mode("classification") %>%
7269
translate()
7370
```
7471

@@ -86,7 +83,6 @@ returned.
8683
```{r spark-reg}
8784
logistic_reg() %>%
8885
set_engine("spark") %>%
89-
set_mode("classification") %>%
9086
translate()
9187
```
9288

@@ -95,7 +91,6 @@ logistic_reg() %>%
9591
```{r keras-reg}
9692
logistic_reg() %>%
9793
set_engine("keras") %>%
98-
set_mode("classification") %>%
9994
translate()
10095
```
10196

0 commit comments

Comments
 (0)