Skip to content

Commit 3a52078

Browse files
committed
Updating for multinom() engine specific args
1 parent c71ad65 commit 3a52078

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

R/multinom_reg.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,8 @@ update.multinom_reg <-
114114
parameters = NULL,
115115
penalty = NULL, mixture = NULL,
116116
fresh = FALSE, ...) {
117-
update_dot_check(...)
117+
118+
eng_args <- update_engine_parameters(object$eng_args, ...)
118119

119120
if (!is.null(parameters)) {
120121
parameters <- check_final_param(parameters)
@@ -128,12 +129,15 @@ update.multinom_reg <-
128129

129130
if (fresh) {
130131
object$args <- args
132+
object$eng_args <- eng_args
131133
} else {
132134
null_args <- map_lgl(args, null_value)
133135
if (any(null_args))
134136
args <- args[!null_args]
135137
if (length(args) > 0)
136138
object$args[names(args)] <- args
139+
if (length(eng_args) > 0)
140+
object$eng_args[names(eng_args)] <- eng_args
137141
}
138142

139143
new_model_spec(

tests/testthat/test_multinom_reg.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ test_that('updating', {
8080
expr1 <- multinom_reg() %>% set_engine("glmnet", intercept = TRUE)
8181
expr1_exp <- multinom_reg(mixture = 0) %>% set_engine("glmnet", intercept = TRUE)
8282

83-
expr2 <- multinom_reg(mixture = varying()) %>% set_engine("glmnet")
83+
expr2 <- multinom_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = varying())
8484
expr2_exp <- multinom_reg(mixture = varying()) %>% set_engine("glmnet", nlambda = 10)
8585

8686
expr3 <- multinom_reg(mixture = 0, penalty = varying()) %>% set_engine("glmnet")
@@ -92,8 +92,8 @@ test_that('updating', {
9292
expr5 <- multinom_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10)
9393
expr5_exp <- multinom_reg(mixture = 1) %>% set_engine("glmnet", nlambda = 10, pmax = 2)
9494

95-
# expect_equal(update(expr1 %>% set_engine("glmnet"), mixture = 0), expr1_exp)
96-
expect_equal(update(expr2) %>% set_engine("glmnet", nlambda = 10), expr2_exp)
95+
expect_equal(update(expr1, mixture = 0), expr1_exp)
96+
expect_equal(update(expr2, nlambda = 10), expr2_exp)
9797
expect_equal(update(expr3, mixture = 1, fresh = TRUE) %>% set_engine("glmnet"), expr3_exp)
9898
# expect_equal(update(expr4 %>% set_engine("glmnet", pmax = 2)), expr4_exp)
9999
expect_equal(update(expr5) %>% set_engine("glmnet", nlambda = 10, pmax = 2), expr5_exp)

0 commit comments

Comments
 (0)