Skip to content

Commit 452be81

Browse files
committed
added pipe-able functions for setting elements
1 parent be44293 commit 452be81

File tree

4 files changed

+124
-0
lines changed

4 files changed

+124
-0
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ export(predict_predint.model_fit)
7979
export(predict_raw)
8080
export(predict_raw.model_fit)
8181
export(rand_forest)
82+
export(set_args)
83+
export(set_mode)
8284
export(show_call)
8385
export(surv_reg)
8486
export(translate)

R/arguments.R

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,56 @@ check_others <- function(args, obj, core_args) {
6464
}
6565
args
6666
}
67+
68+
#' Change elements of a model specification
69+
#'
70+
#' `set_args` can be used to modify the arguments of a model specification while
71+
#' `set_mode` is used to change the model's mode.
72+
#'
73+
#' @param object A model specification.
74+
#' @param ... One or more named model arguments.
75+
#' @param mode A character string for the model type (e.g. "classification" or
76+
#' "regression")
77+
#' @return An updated model object.
78+
#' @details `set_args` will replace existing values of the arguments.
79+
#'
80+
#' @examples
81+
#' rand_forest()
82+
#'
83+
#' rand_forest() %>%
84+
#' set_args(mtry = 3, importance = TRUE) %>%
85+
#' set_mode("regression")
86+
#'
87+
#' @export
88+
set_args <- function(object, ...) {
89+
the_dots <- list(...)
90+
if (length(the_dots) == 0)
91+
stop("Please pass at least one named argument.", call. = FALSE)
92+
main_args <- names(object$args)
93+
new_args <- names(the_dots)
94+
for (i in new_args) {
95+
if (any(main_args == i)) {
96+
object$args[[i]] <- the_dots[[i]]
97+
} else {
98+
object$others[[i]] <- the_dots[[i]]
99+
}
100+
}
101+
object
102+
}
103+
104+
#' @rdname set_args
105+
#' @export
106+
set_mode <- function(object, mode) {
107+
if (is.null(mode))
108+
return(object)
109+
mode <- mode[1]
110+
if (!(any(all_modes == mode))) {
111+
stop("`mode` should be one of ",
112+
paste0("'", all_modes, "'", collapse = ", "),
113+
call. = FALSE)
114+
}
115+
object$mode <- mode
116+
object
117+
}
118+
119+

man/set_args.Rd

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
library(testthat)
2+
library(parsnip)
3+
library(dplyr)
4+
5+
context("changing arguments and engine")
6+
7+
test_that('pipe arguments', {
8+
mod_1 <- rand_forest() %>%
9+
set_args(mtry = 1, something = "blah")
10+
expect_equal(mod_1$args$mtry, 1)
11+
expect_equal(mod_1$others$something, "blah")
12+
13+
mod_2 <- rand_forest(mtry = 2, others = list(var = "x")) %>%
14+
set_args(mtry = 1, something = "blah")
15+
expect_equal(mod_2$args$mtry, 1)
16+
expect_equal(mod_2$others$something, "blah")
17+
expect_equal(mod_2$others$var, "x")
18+
19+
expect_error(rand_forest() %>% set_args())
20+
21+
})
22+
23+
24+
test_that('pipe engine', {
25+
mod_1 <- rand_forest() %>%
26+
set_mode("regression")
27+
expect_equal(mod_1$mode, "regression")
28+
29+
expect_error(rand_forest() %>% set_mode())
30+
expect_error(rand_forest() %>% set_mode(2))
31+
expect_error(rand_forest() %>% set_mode("haberdashery"))
32+
})

0 commit comments

Comments
 (0)