Skip to content

Commit f8310e5

Browse files
topepoDavisVaughanhfrick
authored
Re-registration of model information (#664)
* changes for #653 * minor refactoring * unused variable * update news * unit tests * ugly solution for comparing model info within list columns * Apply suggestions from code review Co-authored-by: Davis Vaughan <davis@rstudio.com> Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> * simplified testing of model info Co-authored-by: Davis Vaughan <davis@rstudio.com> Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com>
1 parent 859f60f commit f8310e5

File tree

4 files changed

+246
-72
lines changed

4 files changed

+246
-72
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@
4646

4747
* xgboost engines now use the new `iterationrange` parameter instead of the deprecated `ntreelimit` (#656).
4848

49+
## Developer
50+
51+
* Model information can be re-registered as long as the information being registered is the same. This is helpful for packages that add new engines and use `devtools::load_all()` (#653).
52+
4953

5054
# parsnip 0.1.7
5155

R/aaa_models.R

Lines changed: 84 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ set_new_model <- function(model) {
540540

541541
current <- get_model_env()
542542

543-
set_env_val("models", c(current$models, model))
543+
set_env_val("models", unique(c(current$models, model)))
544544
set_env_val(model, dplyr::tibble(engine = character(0), mode = character(0)))
545545
set_env_val(
546546
paste0(model, "_pkgs"),
@@ -674,12 +674,12 @@ set_dependency <- function(model, eng, pkg = "parsnip", mode = NULL) {
674674
check_eng_val(eng)
675675
check_pkg_val(pkg)
676676

677-
current <- get_model_env()
678677
model_info <- get_from_env(model)
679678
pkg_info <- get_from_env(paste0(model, "_pkgs"))
680679

681680
# ----------------------------------------------------------------------------
682681
# Check engine
682+
683683
has_engine <-
684684
model_info %>%
685685
dplyr::distinct(engine) %>%
@@ -750,37 +750,77 @@ get_dependency <- function(model) {
750750

751751
# ------------------------------------------------------------------------------
752752

753-
#' @rdname set_new_model
754-
#' @keywords internal
755-
#' @export
756-
set_fit <- function(model, mode, eng, value) {
757-
check_model_exists(model)
758-
check_eng_val(eng)
759-
check_spec_mode_engine_val(model, eng, mode)
760-
check_fit_info(value)
753+
# Decide whether registration information for a model still has to be stored.
#
# Looks up the existing `<model>_<component>` table and compares the rows for
# the same engine/mode (and prediction type, when `component = "predict"`)
# with `candidate`. This lets registration code be re-run as long as the
# information being registered is unchanged. See issue #653.
#
# Arguments:
#   model, mode, eng: identify the model, mode, and engine being registered.
#   candidate: the new row(s) that the caller intends to register.
#   pred_type: prediction type; only used when `component` is "predict".
#   component: suffix of the table to check (callers in this file use "fit",
#     "predict", and "encoding").
#
# Returns `TRUE` when nothing is registered yet for this combination (the
# caller should go on and add `candidate`) and `FALSE` when identical
# information is already present (the caller can skip the update). Note that,
# despite the name, `TRUE` means "not registered yet", not "conflicting".
# Aborts when the existing information differs from `candidate`.
is_discordant_info <- function(model, mode, eng, candidate,
                               pred_type = NULL, component = "fit") {
  current <- get_from_env(paste0(model, "_", component))

  # For older versions of parsnip before set_encoding(): there may be no
  # encoding table at all, in which case everything is new. `&&` is the
  # scalar, short-circuiting operator intended for `if` conditions.
  if (is.null(current) && component == "encoding") {
    return(TRUE)
  }

  # `!!mode` injects the argument value so it is not masked by the `mode`
  # column of the table.
  current <- dplyr::filter(current, engine == eng & mode == !!mode)

  if (component == "predict" && !is.null(pred_type)) {
    current <- dplyr::filter(current, type == pred_type)
    # The leading space lives here so the message below does not contain a
    # doubled space when there is no prediction type to report.
    p_type <- paste0(" and prediction type '", pred_type, "'")
  } else {
    p_type <- ""
  }

  if (nrow(current) == 0) {
    return(TRUE)
  }

  same_info <- isTRUE(all.equal(current, candidate, check.environment = FALSE))

  if (!same_info) {
    rlang::abort(
      glue::glue(
        "The combination of engine '{eng}' and mode '{mode}'{p_type} already has ",
        "{component} data for model '{model}' and the new information being ",
        "registered is different."
      )
    )
  }

  FALSE
}
795+
796+
# Also check for general registration
797+
798+
# Confirms that `model` has an entry for the given engine/mode pair.
# Counts the rows of the table returned by get_from_env(model) whose `engine`
# equals `eng` and whose `mode` equals `mode`, aborts unless exactly one row
# matches, and otherwise returns NULL invisibly.
# NOTE(review): despite the name, this fails when the combination has *not*
# been registered.
check_unregistered <- function(model, mode, eng) {
799+
model_info <- get_from_env(model)
766800
# `!!mode` injects the argument value so it is not masked by the `mode` column.
has_engine <-
767801
model_info %>%
768802
dplyr::filter(engine == eng & mode == !!mode) %>%
769803
nrow()
770804
if (has_engine != 1) {
771-
# Diff view: the abort() call on the `771-`/`772-` lines is the removed
# version; the call on the `805+` to `808+` lines is its replacement.
rlang::abort(glue::glue("The combination of '{eng}' and mode '{mode}' has not ",
772-
"been registered for model '{model}'."))
805+
rlang::abort(
806+
glue::glue("The combination of engine '{eng}' and mode '{mode}' has not ",
807+
"been registered for model '{model}'.")
808+
)
773809
}
810+
invisible(NULL)
811+
}
774812

775-
has_fit <-
776-
old_fits %>%
777-
dplyr::filter(engine == eng & mode == !!mode) %>%
778-
nrow()
779813

780-
if (has_fit > 0) {
781-
rlang::abort(glue::glue("The combination of '{eng}' and mode '{mode}' ",
782-
"already has a fit component for model '{model}'."))
783-
}
814+
815+
#' @rdname set_new_model
816+
#' @keywords internal
817+
#' @export
818+
set_fit <- function(model, mode, eng, value) {
819+
check_model_exists(model)
820+
check_eng_val(eng)
821+
check_spec_mode_engine_val(model, eng, mode)
822+
check_fit_info(value)
823+
check_unregistered(model, mode, eng)
784824

785825
new_fit <-
786826
dplyr::tibble(
@@ -789,6 +829,11 @@ set_fit <- function(model, mode, eng, value) {
789829
value = list(value)
790830
)
791831

832+
if (!is_discordant_info(model, mode, eng, new_fit)) {
833+
return(invisible(NULL))
834+
}
835+
836+
old_fits <- get_from_env(paste0(model, "_fit"))
792837
updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
793838
if (inherits(updated, "try-error")) {
794839
rlang::abort("An error occured when adding the new fit module.")
@@ -824,39 +869,25 @@ set_pred <- function(model, mode, eng, type, value) {
824869
check_eng_val(eng)
825870
check_spec_mode_engine_val(model, eng, mode)
826871
check_pred_info(value, type)
872+
check_unregistered(model, mode, eng)
827873

828-
current <- get_model_env()
829874
model_info <- get_from_env(model)
830-
old_fits <- get_from_env(paste0(model, "_predict"))
831-
832-
has_engine <-
833-
model_info %>%
834-
dplyr::filter(engine == eng & mode == !!mode) %>%
835-
nrow()
836-
if (has_engine != 1) {
837-
rlang::abort(glue::glue("The combination of '{eng}' and mode '{mode}'",
838-
"has not been registered for model '{model}'."))
839-
}
840-
841-
has_pred <-
842-
old_fits %>%
843-
dplyr::filter(engine == eng & mode == !!mode & type == !!type) %>%
844-
nrow()
845-
if (has_pred > 0) {
846-
rlang::abort(glue::glue("The combination of '{eng}', mode '{mode}', ",
847-
"and type '{type}' already has a prediction component",
848-
"for model '{model}'."))
849-
}
850875

851-
new_fit <-
876+
new_pred <-
852877
dplyr::tibble(
853878
engine = eng,
854879
mode = mode,
855880
type = type,
856881
value = list(value)
857882
)
858883

859-
updated <- try(dplyr::bind_rows(old_fits, new_fit), silent = TRUE)
884+
pred_check <- is_discordant_info(model, mode, eng, new_pred, pred_type = type, component = "predict")
885+
if (!pred_check) {
886+
return(invisible(NULL))
887+
}
888+
889+
old_pred <- get_from_env(paste0(model, "_predict"))
890+
updated <- try(dplyr::bind_rows(old_pred, new_pred), silent = TRUE)
860891
if (inherits(updated, "try-error")) {
861892
rlang::abort("An error occured when adding the new fit module.")
862893
}
@@ -1032,25 +1063,15 @@ set_encoding <- function(model, mode, eng, options) {
10321063
options <- tibble::as_tibble(options)
10331064
new_values <- dplyr::bind_cols(keys, options)
10341065

1035-
1036-
current_db_list <- ls(envir = get_model_env())
1037-
nm <- paste(model, "encoding", sep = "_")
1038-
if (any(current_db_list == nm)) {
1039-
current <- get_from_env(nm)
1040-
dup_check <-
1041-
current %>%
1042-
dplyr::inner_join(
1043-
new_values,
1044-
by = c("model", "engine", "mode", "predictor_indicators")
1045-
)
1046-
if (nrow(dup_check)) {
1047-
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings for model '{model}'."))
1048-
}
1049-
1050-
} else {
1051-
current <- NULL
1066+
enc_check <- is_discordant_info(model, mode, eng, new_values, component = "encoding")
1067+
if (!enc_check) {
1068+
return(invisible(NULL))
10521069
}
10531070

1071+
# Allow for older versions before set_encoding() was created
1072+
nm <- paste0(model, "_encoding")
1073+
current <- get_from_env(nm)
1074+
10541075
db_values <- dplyr::bind_rows(current, new_values)
10551076
set_env_val(nm, db_values)
10561077

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# For issue #653 we want to be able to re-run the registration code as
2+
# long as the information being registered is the same.
3+
4+
5+
test_that("re-registration of mode", {
  # Snapshot the mode table, run the mode registration again, and confirm
  # that the call does not error and the table is unchanged.
  modes_before <- get_from_env("bart_modes")

  expect_error(
    set_model_mode("bart", "classification"),
    regexp = NA
  )

  modes_after <- get_from_env("bart_modes")
  expect_equal(modes_before, modes_after)
})
11+
12+
test_that("re-registration of engine", {
  # Running the engine registration again must not error and must leave the
  # model's engine table as it was.
  engines_before <- get_from_env("bart")

  expect_error(
    set_model_engine("bart", mode = "classification", eng = "dbarts"),
    regexp = NA
  )

  engines_after <- get_from_env("bart")
  expect_equal(engines_before, engines_after)
})
21+
22+
23+
test_that("re-registration of package dependencies", {
  # Running the dependency registration again must not error and must leave
  # the package table as it was.
  pkgs_before <- get_from_env("bart_pkgs")

  expect_error(
    set_dependency(model = "bart", eng = "dbarts", pkg = "dbarts"),
    regexp = NA
  )

  pkgs_after <- get_from_env("bart_pkgs")
  expect_equal(pkgs_before, pkgs_after)
})
32+
33+
test_that("re-registration of fit information", {
  # Issues the fit registration for bart/dbarts/regression; the two calls in
  # this test differ only in the `verbose` default.
  register_bart_fit <- function(verbose) {
    set_fit(
      model = "bart",
      eng = "dbarts",
      mode = "regression",
      value = list(
        interface = "data.frame",
        data = c(x = "x.train", y = "y.train"),
        protect = c("x", "y"),
        func = c(pkg = "dbarts", fun = "bart"),
        defaults = list(verbose = verbose, keeptrees = TRUE, keepcall = FALSE)
      )
    )
  }

  fits_before <- get_from_env("bart_fit")
  expect_error(register_bart_fit(verbose = FALSE), regexp = NA)
  fits_after <- get_from_env("bart_fit")
  expect_equal(fits_before, fits_after)

  # Fail if newly registered data is different than existing
  # `verbose` option is different here
  expect_error(
    register_bart_fit(verbose = TRUE),
    "new information being registered is different"
  )
})
71+
72+
test_that("re-registration of encoding information", {
  # Issues the encoding registration for bart/dbarts/regression; the two
  # calls in this test differ only in `compute_intercept`.
  register_bart_encoding <- function(compute_intercept) {
    set_encoding(
      model = "bart",
      eng = "dbarts",
      mode = "regression",
      options = list(
        predictor_indicators = "none",
        compute_intercept = compute_intercept,
        remove_intercept = FALSE,
        allow_sparse_x = FALSE
      )
    )
  }

  encoding_before <- get_from_env("bart_encoding")
  expect_error(register_bart_encoding(compute_intercept = FALSE), regexp = NA)
  encoding_after <- get_from_env("bart_encoding")
  expect_equal(encoding_before, encoding_after)

  # Fail if newly registered data is different than existing
  # `compute_intercept` option is different here
  expect_error(
    register_bart_encoding(compute_intercept = TRUE),
    "new information being registered is different"
  )
})
108+
109+
110+
test_that("re-registration of prediction information", {
  # Issues the "numeric" prediction registration for bart/dbarts/regression;
  # the two calls in this test differ only in the `type` element of `args`.
  register_bart_pred <- function(arg_type) {
    set_pred(
      model = "bart",
      eng = "dbarts",
      mode = "regression",
      type = "numeric",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(pkg = "parsnip", fun = "dbart_predict_calc"),
        args =
          list(
            obj = quote(object),
            new_data = quote(new_data),
            type = arg_type
          )
      )
    )
  }

  preds_before <- get_from_env("bart_predict")
  expect_error(register_bart_pred(arg_type = "numeric"), regexp = NA)
  preds_after <- get_from_env("bart_predict")
  expect_equal(preds_before, preds_after)

  # Fail if newly registered data is different than existing
  # the `type` element of `args` is different here (the prediction type
  # passed to set_pred() is still "numeric")
  expect_error(
    register_bart_pred(arg_type = "tuba"),
    "new information being registered is different"
  )
})
158+

0 commit comments

Comments
 (0)