Skip to content

Commit 413060c

Browse files
committed
updates for other glmnet models #195
1 parent fd898e0 commit 413060c

13 files changed

+63
-36
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@ S3method(translate,boost_tree)
5757
S3method(translate,decision_tree)
5858
S3method(translate,default)
5959
S3method(translate,linear_reg)
60+
S3method(translate,logistic_reg)
6061
S3method(translate,mars)
6162
S3method(translate,mlp)
63+
S3method(translate,multinom_reg)
6264
S3method(translate,nearest_neighbor)
6365
S3method(translate,rand_forest)
6466
S3method(translate,surv_reg)

R/linear_reg.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@
6868
#'
6969
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
7070
#'
71-
#' When using `glmnet` models, there is the option to pass
72-
#' multiple values (or no values) to the `penalty` argument. This
73-
#' can have an effect on the model object results. When using the
71+
#' For `glmnet` models, the full regularization path is always fit regardless
72+
#' of the value given to `penalty`. Also, there is the option to pass
73+
#' multiple values (or no values) to the `penalty` argument. When using the
7474
#' `predict()` method in these cases, the return value depends on
7575
#' the value of `penalty`. When using `predict()`, only a single
7676
#' value of the penalty can be used. When predicting on multiple

R/logistic_reg.R

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@
6666
#'
6767
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")}
6868
#'
69-
#' When using `glmnet` models, there is the option to pass
70-
#' multiple values (or no values) to the `penalty` argument. This
71-
#' can have an effect on the model object results. When using the
69+
#' For `glmnet` models, the full regularization path is always fit regardless
70+
#' of the value given to `penalty`. Also, there is the option to pass
71+
#' multiple values (or no values) to the `penalty` argument. When using the
7272
#' `predict()` method in these cases, the return value depends on
7373
#' the value of `penalty`. When using `predict()`, only a single
7474
#' value of the penalty can be used. When predicting on multiple
@@ -137,6 +137,9 @@ print.logistic_reg <- function(x, ...) {
137137
invisible(x)
138138
}
139139

140+
#' @export
141+
translate.logistic_reg <- translate.linear_reg
142+
140143
# ------------------------------------------------------------------------------
141144

142145
#' @inheritParams update.boost_tree
@@ -235,7 +238,7 @@ organize_glmnet_prob <- function(x, object) {
235238
}
236239

237240
# ------------------------------------------------------------------------------
238-
# glmnet call stack for linear regression using `predict` when object has
241+
# glmnet call stack for logistic regression using `predict` when object has
239242
# classes "_lognet" and "model_fit" (for class predictions):
240243
#
241244
# predict()
@@ -247,7 +250,7 @@ organize_glmnet_prob <- function(x, object) {
247250
# predict.lognet()
248251

249252

250-
# glmnet call stack for linear regression using `multi_predict` when object has
253+
# glmnet call stack for logistic regression using `multi_predict` when object has
251254
# classes "_lognet" and "model_fit" (for class predictions):
252255
#
253256
# multi_predict()
@@ -266,6 +269,11 @@ predict._lognet <- function(object, new_data, type = NULL, opts = list(), penalt
266269
if (any(names(enquos(...)) == "newdata"))
267270
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
268271

272+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
273+
if (is.null(penalty) & !is.null(object$spec$args$penalty)) {
274+
penalty <- object$spec$args$penalty
275+
}
276+
269277
object$spec$args$penalty <- check_penalty(penalty, object, multi)
270278

271279
object$spec <- eval_args(object$spec)
@@ -286,8 +294,16 @@ multi_predict._lognet <-
286294
penalty <- eval_tidy(penalty)
287295

288296
dots <- list(...)
289-
if (is.null(penalty))
290-
penalty <- eval_tidy(object$fit$lambda)
297+
298+
if (is.null(penalty)) {
299+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
300+
if (!is.null(object$spec$args$penalty)) {
301+
penalty <- object$spec$args$penalty
302+
} else {
303+
penalty <- object$fit$lambda
304+
}
305+
}
306+
291307
dots$s <- penalty
292308

293309
if (is.null(type))

R/multinom_reg.R

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@
5757
#'
5858
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")}
5959
#'
60-
#' When using `glmnet` models, there is the option to pass
61-
#' multiple values (or no values) to the `penalty` argument. This
62-
#' can have an effect on the model object results. When using the
60+
#' For `glmnet` models, the full regularization path is always fit regardless
61+
#' of the value given to `penalty`. Also, there is the option to pass
62+
#' multiple values (or no values) to the `penalty` argument. When using the
6363
#' `predict()` method in these cases, the return value depends on
6464
#' the value of `penalty`. When using `predict()`, only a single
6565
#' value of the penalty can be used. When predicting on multiple
@@ -112,14 +112,17 @@ print.multinom_reg <- function(x, ...) {
112112
cat("Multinomial Regression Model Specification (", x$mode, ")\n\n", sep = "")
113113
model_printer(x, ...)
114114

115-
if(!is.null(x$method$fit$args)) {
115+
if (!is.null(x$method$fit$args)) {
116116
cat("Model fit template:\n")
117117
print(show_call(x))
118118
}
119119

120120
invisible(x)
121121
}
122122

123+
#' @export
124+
translate.multinom_reg <- translate.linear_reg
125+
123126
# ------------------------------------------------------------------------------
124127

125128
#' @inheritParams update.boost_tree
@@ -188,7 +191,7 @@ organize_multnet_prob <- function(x, object) {
188191
}
189192

190193
# ------------------------------------------------------------------------------
191-
# glmnet call stack for linear regression using `predict` when object has
194+
# glmnet call stack for multinomial regression using `predict` when object has
192195
# classes "_multnet" and "model_fit" (for class predictions):
193196
#
194197
# predict()
@@ -199,7 +202,7 @@ organize_multnet_prob <- function(x, object) {
199202
# predict.multnet()
200203

201204

202-
# glmnet call stack for linear regression using `multi_predict` when object has
205+
# glmnet call stack for multinomial regression using `multi_predict` when object has
203206
# classes "_multnet" and "model_fit" (for class predictions):
204207
#
205208
# multi_predict()
@@ -217,6 +220,11 @@ organize_multnet_prob <- function(x, object) {
217220
predict._multnet <-
218221
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
219222

223+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
224+
if (is.null(penalty) & !is.null(object$spec$args$penalty)) {
225+
penalty <- object$spec$args$penalty
226+
}
227+
220228
object$spec$args$penalty <- check_penalty(penalty, object, multi)
221229

222230
object$spec <- eval_args(object$spec)
@@ -242,8 +250,14 @@ multi_predict._multnet <-
242250
penalty <- eval_tidy(penalty)
243251

244252
dots <- list(...)
245-
if (is.null(penalty))
246-
penalty <- eval_tidy(object$fit$lambda)
253+
if (is.null(penalty)) {
254+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
255+
if (!is.null(object$spec$args$penalty)) {
256+
penalty <- object$spec$args$penalty
257+
} else {
258+
penalty <- object$fit$lambda
259+
}
260+
}
247261
dots$s <- penalty
248262

249263
if (is.null(type))

R/multinom_reg_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ set_pred(
4444
mode = "classification",
4545
type = "class",
4646
value = list(
47-
pre = check_glmnet_lambda,
47+
pre = NULL,
4848
post = organize_multnet_class,
4949
func = c(fun = "predict"),
5050
args =

man/linear_reg.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/logistic_reg.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/multinom_reg.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_linear_reg.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ test_that('primary arguments', {
7676
x = expr(missing_arg()),
7777
y = expr(missing_arg()),
7878
weights = expr(missing_arg()),
79-
lambda = new_empty_quosure(1),
8079
family = "gaussian"
8180
)
8281
)

tests/testthat/test_logistic_reg.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ test_that('primary arguments', {
7979
x = expr(missing_arg()),
8080
y = expr(missing_arg()),
8181
weights = expr(missing_arg()),
82-
lambda = new_empty_quosure(1),
8382
family = "binomial"
8483
)
8584
)

0 commit comments

Comments
 (0)