Skip to content

Commit af3e0ae

Browse files
authored
Merge pull request #508 from tidymodels/export-convert-functions
Export `.convert_*()` functions
2 parents 2a8da60 + cd7f3d4 commit af3e0ae

File tree

10 files changed

+294
-131
lines changed

10 files changed

+294
-131
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ S3method(varying_args,recipe)
109109
S3method(varying_args,step)
110110
export("%>%")
111111
export(.cols)
112+
export(.convert_form_to_xy_fit)
113+
export(.convert_form_to_xy_new)
114+
export(.convert_xy_to_form_fit)
115+
export(.convert_xy_to_form_new)
112116
export(.dat)
113117
export(.facts)
114118
export(.lvls)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* The helper functions `.convert_form_to_xy_fit()`, `.convert_form_to_xy_new()`, `.convert_xy_to_form_fit()`, and `.convert_xy_to_form_new()` for converting between formula and matrix interface are now exported for developer use (#508).
4+
35
# parsnip 0.1.6
46

57
## Model Specification Changes

R/convert_data.R

Lines changed: 135 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,75 @@
11
# ------------------------------------------------------------------------------
22

3-
# Functions to take a formula interface and get the resulting
4-
# objects (y, x, weights, etc) back. For the most part, this
5-
# emulates the internals of `lm` (and also see the notes at
6-
# https://developer.r-project.org/model-fitting-functions.html).
7-
8-
# `convert_form_to_xy_fit` is for when the data are created for modeling.
9-
# It saves both the data objects as well as the objects needed
10-
# when new data are predicted (e.g. `terms`, etc.).
11-
12-
# `convert_form_to_xy_new` is used when new samples are being predicted
13-
# and only requires the predictors to be available.
14-
3+
#' Helper functions to convert between formula and matrix interface
4+
#'
5+
#' @description
6+
#' Functions to take a formula interface and get the resulting
7+
#' objects (y, x, weights, etc) back or the other way around. The functions are
8+
#' intended for developer use. For the most part, this emulates the internals
9+
#' of `lm()` (and also see the notes at
10+
#' https://developer.r-project.org/model-fitting-functions.html).
11+
#'
12+
#' `.convert_form_to_xy_fit()` and `.convert_xy_to_form_fit()` are for when the
13+
#' data are created for modeling.
14+
#' `.convert_form_to_xy_fit()` saves both the data objects as well as the objects
15+
#' needed when new data are predicted (e.g. `terms`, etc.).
16+
#'
17+
#' `.convert_form_to_xy_new()` and `.convert_xy_to_form_new()` are used when new
18+
#' samples are being predicted and only require the predictors to be available.
19+
#'
20+
#' @param data A data frame containing all relevant variables (e.g. outcome(s),
21+
#' predictors, case weights, etc).
22+
#' @param ... Additional arguments passed to [stats::model.frame()] and
23+
#' specification of `offset` and `contrasts`.
24+
#' @param na.action A function which indicates what should happen when the data
25+
#' contain NAs.
26+
#' @param indicators A string describing whether and how to create
27+
#' indicator/dummy variables from factor predictors. Possible options are
28+
#' `"none"`, `"traditional"`, and `"one_hot"`.
29+
#' @param composition A string describing whether the resulting `x` and `y`
30+
#' should be returned as a `"matrix"` or a `"data.frame"`.
31+
#' @param remove_intercept A logical indicating whether to remove the intercept
32+
#' column after `model.matrix()` is finished.
33+
#' @inheritParams fit.model_spec
34+
#' @rdname convert_helpers
35+
#' @keywords internal
36+
#' @export
37+
#'
1538
#' @importFrom stats .checkMFClasses .getXlevels delete.response
1639
#' @importFrom stats model.offset model.weights na.omit na.pass
17-
18-
convert_form_to_xy_fit <- function(
19-
formula,
20-
data,
21-
...,
22-
na.action = na.omit,
23-
indicators = "traditional",
24-
composition = "data.frame",
25-
remove_intercept = TRUE
26-
) {
27-
if (!(composition %in% c("data.frame", "matrix")))
40+
.convert_form_to_xy_fit <- function(formula,
41+
data,
42+
...,
43+
na.action = na.omit,
44+
indicators = "traditional",
45+
composition = "data.frame",
46+
remove_intercept = TRUE) {
47+
if (!(composition %in% c("data.frame", "matrix"))) {
2848
rlang::abort("`composition` should be either 'data.frame' or 'matrix'.")
49+
}
2950

3051
## Assemble model.frame call from call arguments
3152
mf_call <- quote(model.frame(formula, data))
3253
mf_call$na.action <- match.call()$na.action # TODO this should work better
3354
dots <- quos(...)
3455
check_form_dots(dots)
35-
for(i in seq_along(dots))
36-
mf_call[[ names(dots)[i] ]] <- get_expr(dots[[i]])
56+
for (i in seq_along(dots)) {
57+
mf_call[[names(dots)[i]]] <- get_expr(dots[[i]])
58+
}
3759

3860
# setup contrasts
39-
if (any(names(dots) == "contrasts"))
61+
if (any(names(dots) == "contrasts")) {
4062
contrasts <- eval_tidy(dots[["contrasts"]])
41-
else
63+
} else {
4264
contrasts <- NULL
65+
}
4366

4467
# For new data, save the expression to create offsets (if any)
45-
if (any(names(dots) == "offset"))
68+
if (any(names(dots) == "offset")) {
4669
offset_expr <- get_expr(dots[["offset"]])
47-
else
70+
} else {
4871
offset_expr <- NULL
72+
}
4973

5074
mod_frame <- eval_tidy(mf_call)
5175
mod_terms <- attr(mod_frame, "terms")
@@ -57,20 +81,22 @@ convert_form_to_xy_fit <- function(
5781
y <- model.response(mod_frame, type = "any")
5882

5983
# if y is a numeric vector, model.response() added names
60-
if(is.atomic(y)) {
84+
if (is.atomic(y)) {
6185
names(y) <- NULL
6286
}
6387

6488
w <- as.vector(model.weights(mod_frame))
65-
if (!is.null(w) && !is.numeric(w))
89+
if (!is.null(w) && !is.numeric(w)) {
6690
rlang::abort("`weights` must be a numeric vector")
91+
}
6792

6893
offset <- as.vector(model.offset(mod_frame))
6994
if (!is.null(offset)) {
70-
if (length(offset) != nrow(mod_frame))
95+
if (length(offset) != nrow(mod_frame)) {
7196
rlang::abort(
7297
glue::glue("The offset data should have {nrow(mod_frame)} elements.")
73-
)
98+
)
99+
}
74100
}
75101

76102
if (indicators != "none") {
@@ -82,13 +108,13 @@ convert_form_to_xy_fit <- function(
82108
options(contrasts = new_contr)
83109
}
84110
x <- model.matrix(mod_terms, mod_frame, contrasts)
85-
86111
} else {
87112
# this still ignores -vars in formula
88113
x <- model.frame(mod_terms, data)
89114
y_cols <- attr(mod_terms, "response")
90-
if (length(y_cols) > 0)
91-
x <- x[,-y_cols, drop = FALSE]
115+
if (length(y_cols) > 0) {
116+
x <- x[, -y_cols, drop = FALSE]
117+
}
92118
}
93119

94120
if (remove_intercept) {
@@ -103,8 +129,9 @@ convert_form_to_xy_fit <- function(
103129
)
104130

105131
if (composition == "data.frame") {
106-
if (is.matrix(y))
132+
if (is.matrix(y)) {
107133
y <- as.data.frame(y)
134+
}
108135
res <-
109136
list(
110137
x = as.data.frame(x),
@@ -119,8 +146,9 @@ convert_form_to_xy_fit <- function(
119146
} else {
120147
# Since a matrix is requested, try to convert y but check
121148
# to see if it is possible
122-
if (will_make_matrix(y))
149+
if (will_make_matrix(y)) {
123150
y <- as.matrix(y)
151+
}
124152
res <-
125153
list(
126154
x = x,
@@ -136,10 +164,19 @@ convert_form_to_xy_fit <- function(
136164
res
137165
}
138166

139-
convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,
140-
composition = "data.frame") {
141-
if (!(composition %in% c("data.frame", "matrix")))
167+
168+
#' @param object An object of class `model_fit`.
169+
#' @inheritParams predict.model_fit
170+
#' @rdname convert_helpers
171+
#' @keywords internal
172+
#' @export
173+
.convert_form_to_xy_new <- function(object,
174+
new_data,
175+
na.action = na.pass,
176+
composition = "data.frame") {
177+
if (!(composition %in% c("data.frame", "matrix"))) {
142178
rlang::abort("`composition` should be either 'data.frame' or 'matrix'.")
179+
}
143180

144181
mod_terms <- object$terms
145182
mod_terms <- delete.response(mod_terms)
@@ -153,29 +190,38 @@ convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,
153190
# If offset was done at least once in-line
154191
if (!is.null(offset_cols)) {
155192
offset <- rep(0, nrow(new_data))
156-
for (i in offset_cols)
193+
for (i in offset_cols) {
157194
offset <- offset +
158-
eval_tidy(attr(mod_terms, "variables")[[i + 1]],
159-
new_data) # use na.action here and below?
160-
} else offset <- NULL
195+
eval_tidy(
196+
attr(mod_terms, "variables")[[i + 1]],
197+
new_data
198+
) # use na.action here and below?
199+
}
200+
} else {
201+
offset <- NULL
202+
}
161203

162204
if (!is.null(object$offset_expr)) {
163-
if (is.null(offset))
205+
if (is.null(offset)) {
164206
offset <- rep(0, nrow(new_data))
207+
}
165208
offset <- offset + eval_tidy(object$offset_expr, new_data)
166209
}
167210

168211
new_data <-
169-
model.frame(mod_terms,
170-
new_data,
171-
na.action = na.action,
172-
xlev = object$xlevels)
212+
model.frame(
213+
mod_terms,
214+
new_data,
215+
na.action = na.action,
216+
xlev = object$xlevels
217+
)
173218

174219
cl <- attr(mod_terms, "dataClasses")
175-
if (!is.null(cl))
220+
if (!is.null(cl)) {
176221
.checkMFClasses(cl, new_data)
222+
}
177223

178-
if(object$options$indicators != "none") {
224+
if (object$options$indicators != "none") {
179225
if (object$options$indicators == "one_hot") {
180226
old_contr <- options("contrasts")$contrasts
181227
on.exit(options(contrasts = old_contr))
@@ -187,15 +233,16 @@ convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,
187233
model.matrix(mod_terms, new_data, contrasts.arg = object$contrasts)
188234
}
189235

190-
if(object$options$remove_intercept) {
236+
if (object$options$remove_intercept) {
191237
new_data <- new_data[, colnames(new_data) != "(Intercept)", drop = FALSE]
192238
}
193239

194-
if (composition == "data.frame")
240+
if (composition == "data.frame") {
195241
new_data <- as.data.frame(new_data)
196-
else {
197-
if (will_make_matrix(new_data))
242+
} else {
243+
if (will_make_matrix(new_data)) {
198244
new_data <- as.matrix(new_data)
245+
}
199246
}
200247
list(x = new_data, offset = offset)
201248
}
@@ -205,21 +252,35 @@ convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,
205252
# The other direction where we make a formula from the data
206253
# objects
207254

208-
#' @importFrom dplyr bind_cols
209255
# TODO slots for other roles
210-
convert_xy_to_form_fit <- function(x, y, weights = NULL, y_name = "..y",
211-
remove_intercept = TRUE) {
212-
if (is.vector(x))
256+
#' @param weights A numeric vector containing the weights.
257+
#' @param y_name A string specifying the name of the outcome.
258+
#' @inheritParams fit.model_spec
259+
#' @inheritParams .convert_form_to_xy_fit
260+
#'
261+
#' @rdname convert_helpers
262+
#' @keywords internal
263+
#' @export
264+
#'
265+
#' @importFrom dplyr bind_cols
266+
.convert_xy_to_form_fit <- function(x,
267+
y,
268+
weights = NULL,
269+
y_name = "..y",
270+
remove_intercept = TRUE) {
271+
if (is.vector(x)) {
213272
rlang::abort("`x` cannot be a vector.")
273+
}
214274

215-
if(remove_intercept) {
275+
if (remove_intercept) {
216276
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
217277
}
218278

219279
rn <- rownames(x)
220280

221-
if (!is.data.frame(x))
281+
if (!is.data.frame(x)) {
222282
x <- as.data.frame(x)
283+
}
223284

224285
if (is.matrix(y)) {
225286
y <- as.data.frame(y)
@@ -235,14 +296,17 @@ convert_xy_to_form_fit <- function(x, y, weights = NULL, y_name = "..y",
235296
form <- make_formula(names(x), names(y))
236297

237298
x <- bind_cols(x, y)
238-
if(!is.null(rn) & !inherits(x, "tbl_df"))
299+
if (!is.null(rn) & !inherits(x, "tbl_df")) {
239300
rownames(x) <- rn
301+
}
240302

241303
if (!is.null(weights)) {
242-
if (!is.numeric(weights))
304+
if (!is.numeric(weights)) {
243305
rlang::abort("`weights` must be a numeric vector")
244-
if (length(weights) != nrow(x))
306+
}
307+
if (length(weights) != nrow(x)) {
245308
rlang::abort(glue::glue("`weights` should have {nrow(x)} elements"))
309+
}
246310
}
247311

248312
res <- list(
@@ -254,10 +318,14 @@ convert_xy_to_form_fit <- function(x, y, weights = NULL, y_name = "..y",
254318
res
255319
}
256320

257-
convert_xy_to_form_new <- function(object, new_data) {
321+
#' @rdname convert_helpers
322+
#' @keywords internal
323+
#' @export
324+
.convert_xy_to_form_new <- function(object, new_data) {
258325
new_data <- new_data[, object$x_var, drop = FALSE]
259-
if (!is.data.frame(new_data))
326+
if (!is.data.frame(new_data)) {
260327
new_data <- as.data.frame(new_data)
328+
}
261329
new_data
262330
}
263331

@@ -350,4 +418,3 @@ maybe_data_frame <- function(x) {
350418
}
351419
x
352420
}
353-

0 commit comments

Comments
 (0)