Skip to content

Commit 6c2952a

Browse files
committed
Updating for surv_reg() engine specific args
1 parent ede3d8a commit 6c2952a

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

R/surv_reg.R

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ print.surv_reg <- function(x, ...) {
9696
#' @rdname surv_reg
9797
#' @export
9898
update.surv_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALSE, ...) {
99-
update_dot_check(...)
99+
100+
eng_args <- update_engine_parameters(object$eng_args, ...)
100101

101102
if (!is.null(parameters)) {
102103
parameters <- check_final_param(parameters)
@@ -110,12 +111,15 @@ update.surv_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALS
110111

111112
if (fresh) {
112113
object$args <- args
114+
object$eng_args <- eng_args
113115
} else {
114116
null_args <- map_lgl(args, null_value)
115117
if (any(null_args))
116118
args <- args[!null_args]
117119
if (length(args) > 0)
118120
object$args[names(args)] <- args
121+
if (length(eng_args) > 0)
122+
object$eng_args[names(eng_args)] <- eng_args
119123
}
120124

121125
new_model_spec(

tests/testthat/test_surv_reg.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ test_that('engine arguments', {
6060

6161

6262
test_that('updating', {
63-
expr1 <- surv_reg() %>% set_engine("flexsurv", cl = .99)
63+
expr1 <- surv_reg() %>% set_engine("flexsurv", cl = varying())
6464
expr1_exp <- surv_reg(dist = "lnorm") %>% set_engine("flexsurv", cl = .99)
65-
expect_equal(update(expr1, dist = "lnorm"), expr1_exp)
65+
expect_equal(update(expr1, dist = "lnorm", cl = 0.99), expr1_exp)
6666

6767
param_tibb <- tibble::tibble(dist = "weibull")
6868
param_list <- as.list(param_tibb)

0 commit comments

Comments
 (0)