Skip to content

Commit 34c0323

Browse files
committed
working versions of fit for formulas and recipes
1 parent c36e7de commit 34c0323

File tree

6 files changed

+88
-29
lines changed

6 files changed

+88
-29
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,7 @@ importFrom(rlang,ll)
2222
importFrom(rlang,missing_arg)
2323
importFrom(rlang,na_lgl)
2424
importFrom(rlang,quos)
25+
importFrom(stats,as.formula)
26+
importFrom(stats,model.frame)
2527
importFrom(utils,capture.output)
2628
importFrom(utils,installed.packages)

R/fit.R

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# General TODOs
2+
# - think about case weights in each instance below
3+
# - try/catch all model fit evaluations
4+
# - option to capture output/verboseness
5+
6+
17

28
#' Fit a Model Specification to a Dataset
39
#'
@@ -11,7 +17,7 @@
1117
fit <- function (object, ...)
1218
UseMethod("fit")
1319

14-
# The S3 part here is awful for now
20+
# The S3 part here is awful (for now I hope)
1521

1622
#' @importFrom utils capture.output
1723
# fit_formula <- function(object, formula, data, verboseness = 0, engine = "ranger") {
@@ -44,10 +50,10 @@ fit.model_spec <- function(object, x, engine = object$engine, ...) {
4450
res <- fit_formula(object, formula = x, ...)
4551
} else {
4652
if (inherits(x, c("matrix", "data.frame"))) {
47-
res <- fit_xy(object, formula = x, ...)
53+
res <- fit_xy(object, x = x, ...)
4854
} else {
4955
if (inherits(x, "recipe")) {
50-
res <- fit_recipe(object, formula = x, ...)
56+
res <- fit_recipe(object, recipe = x, ...)
5157
} else {
5258
stop("`x` should be a formula, data frame, matrix, or recipe")
5359
}
@@ -58,57 +64,115 @@ fit.model_spec <- function(object, x, engine = object$engine, ...) {
5864

5965
###################################################################
6066

61-
fit_formula <- function(object, formula = x, engine = engine, ...) {
67+
#' @importFrom rlang eval_tidy quos
68+
#' @importFrom stats as.formula
69+
fit_formula <- function(object, formula, engine = engine, ...) {
6270
opts <- quos(...)
71+
6372
if(!any(names(opts) == "data"))
6473
stop("Please pass a data frame with the `data` argument.",
6574
call. = FALSE)
75+
76+
# TODO Should probably just load the namespace
77+
for(pkg in object$method$library)
78+
suppressPackageStartupMessages(library(pkg, character.only = TRUE))
6679

6780
# Look up the model's interface (e.g. formula, recipes, etc)
6881
# and delagate to the connector functions (`formula_to_recipe` etc)
69-
82+
if(object$method$interface == "formula") {
83+
fit_expr <- sub_arg_values(object$method$fit, opts["data"])
84+
fit_expr$formula <- as.formula(eval(formula))
85+
res <- rlang:::eval_tidy(fit_expr)
86+
} else {
87+
if(object$method$interface %in% c("data.frame", "matrix")) {
88+
res <- formula_to_xy(object = object, formula = formula, data = opts["data"])
89+
} else {
90+
stop("I don't know about the ",
91+
object$method$interface, " interface.",
92+
call. = FALSE)
93+
}
94+
}
95+
res
7096
}
7197

72-
fit_xy <- function(object, formula = x, ...) {
98+
fit_xy <- function(object, x, ...) {
99+
# Look up the model's interface (e.g. formula, recipes, etc)
100+
# and delagate to the connector functions (`xy_to_formula` etc)
101+
opts <- quos(...)
102+
103+
if(!any(names(opts) == "y"))
104+
stop("Please pass a data frame with the `y` argument.",
105+
call. = FALSE)
106+
107+
# TODO Should probably just load the namespace
108+
for(pkg in object$method$library)
109+
suppressPackageStartupMessages(library(pkg, character.only = TRUE))
110+
73111
# Look up the model's interface (e.g. formula, recipes, etc)
74112
# and delagate to the connector functions (`xy_to_formula` etc)
113+
if(object$method$interface == "formula") {
114+
res <- xy_to_formula(object = object, x = x, y = opts["y"])
115+
} else {
116+
if(object$method$interface %in% c("data.frame", "matrix")) {
117+
fit_expr <- sub_arg_values(object$method$fit, opts["y"])
118+
res <- rlang:::eval_tidy(fit_expr)
119+
} else {
120+
stop("I don't know about the ",
121+
object$method$interface, " interface.",
122+
call. = FALSE)
123+
}
124+
}
125+
res
75126
}
76127

77-
fit_recipe <- function(object, formula = x, ...) {
128+
fit_recipe <- function(object, recipe, ...) {
78129
# Look up the model's interface (e.g. formula, recipes, etc)
79130
# and delagate to the connector functions (`recipe_to_formula` etc)
80131
}
81132

82133
###################################################################
83134

84-
formula_to_recipe <- function(formula, data) {
135+
formula_to_recipe <- function(object, formula, data) {
85136
# execute the formula
86137
# extract terms _and roles_
87138
# put into recipe
88139

89140
}
90141

91-
formula_to_xy <- function(formula, data) {
142+
#' @importFrom stats model.frame
143+
formula_to_xy <- function(object, formula, data) {
144+
# TODO how do we fill in the other standard things here (subset, contrasts etc)?
145+
# TODO add a "matrix" option here and invoke model.matrix
92146

147+
# Q: avoid eval using ?rlang:::get_expr(data[["data"]])
148+
x <- stats::model.frame(formula, eval_tidy(data[["data"]]))
149+
y <- model.response(x, "numeric")
150+
eval_tidy(object$method$fit)
93151
}
94152

95153
###################################################################
96154

97-
recipe_to_formula <- function(recipe, data) {
155+
recipe_to_formula <- function(object, recipe, data) {
98156

99157
}
100158

101-
recipe_to_xy <- function(recipe, data) {
159+
recipe_to_xy <- function(object, recipe, data) {
102160

103161
}
104162

105163
###################################################################
106164

107-
xy_to_formula <- function(x, y) {
108-
165+
xy_to_formula <- function(object, x, y) {
166+
if(!is.data.frame(x))
167+
x <- as.data.frame(x)
168+
x$.y <- eval_tidy(y[["y"]])
169+
fit_expr <- object$method$fit
170+
fit_expr$formula <- as.formula(.y ~ .)
171+
fit_expr$data <- quote(x)
172+
rlang:::eval_tidy(fit_expr)
109173
}
110174

111-
xy_to_recipe <- function(x, y) {
175+
xy_to_recipe <- function(object, x, y) {
112176

113177
}
114178

R/rand_forest.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ get_randomForest_regression <- function () {
205205
list(library = libs, interface = interface, fit = fit, protect = protect)
206206
}
207207

208-
get_sparklyr_regression <- function () {
208+
get_spark_regression <- function () {
209209
libs <- "sparklyr"
210210
interface <- "data.frame" # adjust this to something else
211211
protect = c("x", "formula", "label_col", "features_col")
@@ -236,7 +236,6 @@ get_sparklyr_regression <- function () {
236236
list(library = libs, interface = interface, fit = fit, protect = protect)
237237
}
238238

239-
240239
get_ranger_classification <- function () {
241240
libs <- "ranger"
242241
interface <- "formula"
@@ -312,7 +311,7 @@ get_randomForest_classification <- function () {
312311
list(library = libs, interface = interface, fit = fit, protect = protect)
313312
}
314313

315-
get_sparklyr_regression <- function () {
314+
get_spark_regression <- function () {
316315
libs <- "sparklyr"
317316
interface <- "data.frame" # adjust this to something else
318317
protect = c("x", "formula", "label_col", "features_col")
@@ -346,7 +345,6 @@ get_sparklyr_regression <- function () {
346345
list(library = libs, interface = interface, fit = fit, protect = protect)
347346
}
348347

349-
350348
###################################################################
351349

352350
# finalizing the model consists of:
@@ -447,7 +445,7 @@ update.rand_forest <-
447445
rand_forest_arg_key <- data.frame(
448446
randomForest = c("mtry", "ntree", "nodesize"),
449447
ranger = c("mtry", "num.trees", "min.node.size"),
450-
sparklyr =
448+
spark =
451449
c("feature_subset_strategy", "num_trees", "min_instances_per_node"),
452450
stringsAsFactors = FALSE,
453451
row.names = c("mtry", "trees", "min_n")
@@ -458,7 +456,7 @@ rand_forest_modes <- c("classification", "regression", "unknown")
458456
rand_forest_engines <- data.frame(
459457
ranger = c(TRUE, TRUE, FALSE),
460458
randomForest = c(TRUE, TRUE, FALSE),
461-
sparklyr = c(TRUE, TRUE, FALSE),
459+
spark = c(TRUE, TRUE, FALSE),
462460
row.names = c("classification", "regression", "unknown")
463461
)
464462

man/finalize.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rand_forest.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/varying.Rd

Lines changed: 1 addition & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)