Skip to content

Commit ec6b08e

Browse files
committed
Updating for svm_poly() engine specific args
1 parent 6c2952a commit ec6b08e

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

R/svm_poly.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ update.svm_poly <-
9898
cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL,
9999
fresh = FALSE,
100100
...) {
101-
update_dot_check(...)
101+
102+
eng_args <- update_engine_parameters(object$eng_args, ...)
102103

103104
if (!is.null(parameters)) {
104105
parameters <- check_final_param(parameters)
@@ -115,12 +116,15 @@ update.svm_poly <-
115116

116117
if (fresh) {
117118
object$args <- args
119+
object$eng_args <- eng_args
118120
} else {
119121
null_args <- map_lgl(args, null_value)
120122
if (any(null_args))
121123
args <- args[!null_args]
122124
if (length(args) > 0)
123125
object$args[names(args)] <- args
126+
if (length(eng_args) > 0)
127+
object$eng_args[names(eng_args)] <- eng_args
124128
}
125129

126130
new_model_spec(

tests/testthat/test_svm_poly.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,14 @@ test_that('updating', {
8080
expr1 <- svm_poly(mode = "regression") %>% set_engine("kernlab", cross = 10)
8181
expr1_exp <- svm_poly(mode = "regression", degree = 1) %>% set_engine("kernlab", cross = 10)
8282

83-
expr2 <- svm_poly(mode = "regression", degree = varying()) %>% set_engine("kernlab")
84-
expr2_exp <- svm_poly(mode = "regression", degree = varying(), scale_factor = 1) %>% set_engine("kernlab")
83+
expr2 <- svm_poly(mode = "regression", degree = varying()) %>% set_engine("kernlab", cross = varying())
84+
expr2_exp <- svm_poly(mode = "regression", degree = varying(), scale_factor = 1) %>% set_engine("kernlab", cross = 10)
8585

8686
expr3 <- svm_poly(mode = "regression", degree = 2, scale_factor = varying()) %>% set_engine("kernlab")
8787
expr3_exp <- svm_poly(mode = "regression", degree = 3) %>% set_engine("kernlab")
8888

8989
expect_equal(update(expr1, degree = 1), expr1_exp)
90-
expect_equal(update(expr2, scale_factor = 1), expr2_exp)
90+
expect_equal(update(expr2, scale_factor = 1, cross = 10), expr2_exp)
9191
expect_equal(update(expr3, degree = 3, fresh = TRUE), expr3_exp)
9292

9393
param_tibb <- tibble::tibble(degree = 3, cost = 10)

0 commit comments

Comments
 (0)