Skip to content

Commit 2c704e9

Browse files
committed
Merge commit 'a517c87205639ee55b72dd3d843ee91caeada70c'
2 parents 1229ad1 + a517c87 commit 2c704e9

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

R/logistic_reg.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,6 @@ multi_predict._lognet <-
271271
}
272272
}
273273

274-
dots$s <- penalty
275-
276274
if (is.null(type))
277275
type <- "class"
278276
if (!(type %in% c("class", "prob", "link", "raw"))) {
@@ -284,7 +282,9 @@ multi_predict._lognet <-
284282
dots$type <- type
285283

286284
object$spec <- eval_args(object$spec)
287-
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
285+
pred <- predict._lognet(object, new_data = new_data, type = "raw",
286+
opts = dots, penalty = penalty, multi = TRUE)
287+
288288
param_key <- tibble(group = colnames(pred), penalty = penalty)
289289
pred <- as_tibble(pred)
290290
pred$.row <- 1:nrow(pred)
@@ -340,6 +340,7 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
340340
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
341341

342342
object$spec <- eval_args(object$spec)
343+
opts$s <- object$spec$args$penalty
343344
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
344345
}
345346

R/multinom_reg.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ multi_predict._multnet <-
200200
penalty <- eval_tidy(penalty)
201201

202202
dots <- list(...)
203+
203204
if (is.null(penalty)) {
204205
# See discussion in https://github.com/tidymodels/parsnip/issues/195
205206
if (!is.null(object$spec$args$penalty)) {
@@ -208,7 +209,6 @@ multi_predict._multnet <-
208209
penalty <- object$fit$lambda
209210
}
210211
}
211-
dots$s <- penalty
212212

213213
if (is.null(type))
214214
type <- "class"
@@ -221,7 +221,8 @@ multi_predict._multnet <-
221221
dots$type <- type
222222

223223
object$spec <- eval_args(object$spec)
224-
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
224+
pred <- predict._multnet(object, new_data = new_data, type = "raw",
225+
opts = dots, penalty = penalty, multi = TRUE)
225226

226227
format_probs <- function(x) {
227228
x <- as_tibble(x)
@@ -269,5 +270,6 @@ predict_classprob._multnet <- function(object, new_data, ...) {
269270
#' @export
270271
predict_raw._multnet <- function(object, new_data, opts = list(), ...) {
271272
object$spec <- eval_args(object$spec)
273+
opts$s <- object$spec$args$penalty
272274
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
273275
}

0 commit comments

Comments
 (0)