Skip to content

Commit 30e3c68

Browse files
authored
Merge pull request #103 from tidymodels/predict-and-new-models
Predict and new models
2 parents fcace3f + 90e1514 commit 30e3c68

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

68 files changed

+3018
-262
lines changed

DESCRIPTION

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Depends:
1717
R (>= 2.10)
1818
Imports:
1919
dplyr,
20-
rlang (>= 0.2.0.9001),
20+
rlang (>= 0.3.0.1),
2121
purrr,
2222
utils,
2323
tibble,
@@ -38,6 +38,4 @@ Suggests:
3838
C50,
3939
xgboost,
4040
covr
41-
Remotes:
42-
tidyverse/rlang,
43-
r-lib/generics
41+

NAMESPACE

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ S3method(predict_classprob,"_lognet")
1919
S3method(predict_classprob,"_multnet")
2020
S3method(predict_classprob,model_fit)
2121
S3method(predict_confint,model_fit)
22-
S3method(predict_num,"_elnet")
23-
S3method(predict_num,model_fit)
22+
S3method(predict_numeric,"_elnet")
23+
S3method(predict_numeric,model_fit)
2424
S3method(predict_predint,model_fit)
2525
S3method(predict_quantile,model_fit)
2626
S3method(predict_raw,"_elnet")
@@ -38,12 +38,16 @@ S3method(print,multinom_reg)
3838
S3method(print,nearest_neighbor)
3939
S3method(print,rand_forest)
4040
S3method(print,surv_reg)
41+
S3method(print,svm_poly)
42+
S3method(print,svm_rbf)
4143
S3method(translate,boost_tree)
4244
S3method(translate,default)
4345
S3method(translate,mars)
4446
S3method(translate,mlp)
4547
S3method(translate,rand_forest)
4648
S3method(translate,surv_reg)
49+
S3method(translate,svm_poly)
50+
S3method(translate,svm_rbf)
4751
S3method(type_sum,model_fit)
4852
S3method(type_sum,model_spec)
4953
S3method(update,boost_tree)
@@ -55,6 +59,8 @@ S3method(update,multinom_reg)
5559
S3method(update,nearest_neighbor)
5660
S3method(update,rand_forest)
5761
S3method(update,surv_reg)
62+
S3method(update,svm_poly)
63+
S3method(update,svm_rbf)
5864
S3method(varying_args,model_spec)
5965
S3method(varying_args,recipe)
6066
S3method(varying_args,step)
@@ -92,8 +98,8 @@ export(predict_classprob)
9298
export(predict_classprob.model_fit)
9399
export(predict_confint)
94100
export(predict_confint.model_fit)
95-
export(predict_num)
96-
export(predict_num.model_fit)
101+
export(predict_numeric)
102+
export(predict_numeric.model_fit)
97103
export(predict_predint)
98104
export(predict_predint.model_fit)
99105
export(predict_quantile)
@@ -106,6 +112,8 @@ export(set_engine)
106112
export(set_mode)
107113
export(show_call)
108114
export(surv_reg)
115+
export(svm_poly)
116+
export(svm_rbf)
109117
export(translate)
110118
export(varying)
111119
export(varying_args)

R/boost_tree.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ xgb_pred <- function(object, newdata, ...) {
359359
#' @export
360360
multi_predict._xgb.Booster <-
361361
function(object, new_data, type = NULL, trees = NULL, ...) {
362+
if (any(names(enquos(...)) == "newdata"))
363+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
364+
362365
if (is.null(trees))
363366
trees <- object$fit$nIter
364367
trees <- sort(trees)
@@ -388,10 +391,10 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
388391
nms <- names(pred)
389392
} else {
390393
if (type == "class") {
391-
pred <- boost_tree_xgboost_data$classes$post(pred, object)
394+
pred <- boost_tree_xgboost_data$class$post(pred, object)
392395
pred <- tibble(.pred = factor(pred, levels = object$lvl))
393396
} else {
394-
pred <- boost_tree_xgboost_data$prob$post(pred, object)
397+
pred <- boost_tree_xgboost_data$classprob$post(pred, object)
395398
pred <- as_tibble(pred)
396399
names(pred) <- paste0(".pred_", names(pred))
397400
}
@@ -458,6 +461,9 @@ C5.0_train <-
458461
#' @export
459462
multi_predict._C5.0 <-
460463
function(object, new_data, type = NULL, trees = NULL, ...) {
464+
if (any(names(enquos(...)) == "newdata"))
465+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
466+
461467
if (is.null(trees))
462468
trees <- min(object$fit$trials)
463469
trees <- sort(trees)

R/boost_tree_data.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ boost_tree_xgboost_data <-
3131
verbose = 0
3232
)
3333
),
34-
pred = list(
34+
numeric = list(
3535
pre = NULL,
3636
post = NULL,
3737
func = c(fun = "xgb_pred"),
@@ -41,7 +41,7 @@ boost_tree_xgboost_data <-
4141
newdata = quote(new_data)
4242
)
4343
),
44-
classes = list(
44+
class = list(
4545
pre = NULL,
4646
post = function(x, object) {
4747
if (is.vector(x)) {
@@ -58,7 +58,7 @@ boost_tree_xgboost_data <-
5858
newdata = quote(new_data)
5959
)
6060
),
61-
prob = list(
61+
classprob = list(
6262
pre = NULL,
6363
post = function(x, object) {
6464
if (is.vector(x)) {
@@ -97,7 +97,7 @@ boost_tree_C5.0_data <-
9797
func = c(pkg = "parsnip", fun = "C5.0_train"),
9898
defaults = list()
9999
),
100-
classes = list(
100+
class = list(
101101
pre = NULL,
102102
post = NULL,
103103
func = c(fun = "predict"),
@@ -106,7 +106,7 @@ boost_tree_C5.0_data <-
106106
newdata = quote(new_data)
107107
)
108108
),
109-
prob = list(
109+
classprob = list(
110110
pre = NULL,
111111
post = function(x, object) {
112112
as_tibble(x)
@@ -142,7 +142,7 @@ boost_tree_spark_data <-
142142
seed = expr(sample.int(10^5, 1))
143143
)
144144
),
145-
pred = list(
145+
numeric = list(
146146
pre = NULL,
147147
post = format_spark_num,
148148
func = c(pkg = "sparklyr", fun = "ml_predict"),
@@ -152,7 +152,7 @@ boost_tree_spark_data <-
152152
dataset = quote(new_data)
153153
)
154154
),
155-
classes = list(
155+
class = list(
156156
pre = NULL,
157157
post = format_spark_class,
158158
func = c(pkg = "sparklyr", fun = "ml_predict"),
@@ -162,7 +162,7 @@ boost_tree_spark_data <-
162162
dataset = quote(new_data)
163163
)
164164
),
165-
prob = list(
165+
classprob = list(
166166
pre = NULL,
167167
post = format_spark_probs,
168168
func = c(pkg = "sparklyr", fun = "ml_predict"),

R/linear_reg.R

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
#'
33
#' `linear_reg` is a way to generate a _specification_ of a model
44
#' before fitting and allows the model to be created using
5-
#' different packages in R, Stan, or via Spark. The main arguments for the
6-
#' model are:
5+
#' different packages in R, Stan, keras, or via Spark. The main
6+
#' arguments for the model are:
77
#' \itemize{
88
#' \item \code{penalty}: The total amount of regularization
9-
#' in the model. Note that this must be zero for some engines .
9+
#' in the model. Note that this must be zero for some engines.
1010
#' \item \code{mixture}: The proportion of L1 regularization in
1111
#' the model. Note that this will be ignored for some engines.
1212
#' }
@@ -19,8 +19,11 @@
1919
#' @inheritParams boost_tree
2020
#' @param mode A single character string for the type of model.
2121
#' The only possible value for this model is "regression".
22-
#' @param penalty An non-negative number representing the
23-
#' total amount of regularization (`glmnet` and `spark` only).
22+
#' @param penalty An non-negative number representing the total
23+
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
24+
#' For `keras` models, this corresponds to purely L2 regularization
25+
#' (aka weight decay) while the other models can be a combination
26+
#' of L1 and L2 (depending on the value of `mixture`).
2427
#' @param mixture A number between zero and one (inclusive) that
2528
#' represents the proportion of regularization that is used for the
2629
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
@@ -36,6 +39,7 @@
3639
#' \item \pkg{R}: `"lm"` or `"glmnet"`
3740
#' \item \pkg{Stan}: `"stan"`
3841
#' \item \pkg{Spark}: `"spark"`
42+
#' \item \pkg{keras}: `"keras"`
3943
#' }
4044
#'
4145
#' @section Engine Details:
@@ -59,6 +63,10 @@
5963
#' \pkg{spark}
6064
#'
6165
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
66+
#'
67+
#' \pkg{keras}
68+
#'
69+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
6270
#'
6371
#' When using `glmnet` models, there is the option to pass
6472
#' multiple values (or no values) to the `penalty` argument.
@@ -211,18 +219,27 @@ organize_glmnet_pred <- function(x, object) {
211219
#' @export
212220
predict._elnet <-
213221
function(object, new_data, type = NULL, opts = list(), ...) {
222+
if (any(names(enquos(...)) == "newdata"))
223+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
224+
214225
object$spec <- eval_args(object$spec)
215226
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
216227
}
217228

218229
#' @export
219-
predict_num._elnet <- function(object, new_data, ...) {
230+
predict_numeric._elnet <- function(object, new_data, ...) {
231+
if (any(names(enquos(...)) == "newdata"))
232+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
233+
220234
object$spec <- eval_args(object$spec)
221-
predict_num.model_fit(object, new_data = new_data, ...)
235+
predict_numeric.model_fit(object, new_data = new_data, ...)
222236
}
223237

224238
#' @export
225239
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
240+
if (any(names(enquos(...)) == "newdata"))
241+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
242+
226243
object$spec <- eval_args(object$spec)
227244
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
228245
}
@@ -232,6 +249,9 @@ predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
232249
#' @export
233250
multi_predict._elnet <-
234251
function(object, new_data, type = NULL, penalty = NULL, ...) {
252+
if (any(names(enquos(...)) == "newdata"))
253+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
254+
235255
dots <- list(...)
236256
if (is.null(penalty))
237257
penalty <- object$fit$lambda

R/linear_reg_data.R

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@ linear_reg_arg_key <- data.frame(
44
glmnet = c( "lambda", "alpha"),
55
spark = c("reg_param", "elastic_net_param"),
66
stan = c( NA, NA),
7+
keras = c( "decay", NA),
78
stringsAsFactors = FALSE,
89
row.names = c("penalty", "mixture")
910
)
1011

1112
linear_reg_modes <- "regression"
1213

1314
linear_reg_engines <- data.frame(
14-
lm = TRUE,
15+
lm = TRUE,
1516
glmnet = TRUE,
1617
spark = TRUE,
1718
stan = TRUE,
19+
keras = TRUE,
1820
row.names = c("regression")
1921
)
2022

@@ -30,7 +32,7 @@ linear_reg_lm_data <-
3032
func = c(pkg = "stats", fun = "lm"),
3133
defaults = list()
3234
),
33-
pred = list(
35+
numeric = list(
3436
pre = NULL,
3537
post = NULL,
3638
func = c(fun = "predict"),
@@ -100,7 +102,7 @@ linear_reg_glmnet_data <-
100102
family = "gaussian"
101103
)
102104
),
103-
pred = list(
105+
numeric = list(
104106
pre = NULL,
105107
post = organize_glmnet_pred,
106108
func = c(fun = "predict"),
@@ -135,7 +137,7 @@ linear_reg_stan_data <-
135137
family = expr(stats::gaussian)
136138
)
137139
),
138-
pred = list(
140+
numeric = list(
139141
pre = NULL,
140142
post = NULL,
141143
func = c(fun = "predict"),
@@ -224,7 +226,7 @@ linear_reg_spark_data <-
224226
protect = c("x", "formula", "weight_col"),
225227
func = c(pkg = "sparklyr", fun = "ml_linear_regression")
226228
),
227-
pred = list(
229+
numeric = list(
228230
pre = NULL,
229231
post = function(results, object) {
230232
results <- dplyr::rename(results, pred = prediction)
@@ -240,5 +242,24 @@ linear_reg_spark_data <-
240242
)
241243
)
242244

243-
245+
linear_reg_keras_data <-
246+
list(
247+
libs = c("keras", "magrittr"),
248+
fit = list(
249+
interface = "matrix",
250+
protect = c("x", "y"),
251+
func = c(pkg = "parsnip", fun = "keras_mlp"),
252+
defaults = list(hidden_units = 1, act = "linear")
253+
),
254+
numeric = list(
255+
pre = NULL,
256+
post = maybe_multivariate,
257+
func = c(fun = "predict"),
258+
args =
259+
list(
260+
object = quote(object$fit),
261+
x = quote(as.matrix(new_data))
262+
)
263+
)
264+
)
244265

0 commit comments

Comments
 (0)