Skip to content

Commit 620adaf

Browse files
committed
classes -> class for #65
1 parent 8005421 commit 620adaf

File tree

9 files changed

+23
-23
lines changed

9 files changed

+23
-23
lines changed

R/boost_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
391391
nms <- names(pred)
392392
} else {
393393
if (type == "class") {
394-
pred <- boost_tree_xgboost_data$classes$post(pred, object)
394+
pred <- boost_tree_xgboost_data$class$post(pred, object)
395395
pred <- tibble(.pred = factor(pred, levels = object$lvl))
396396
} else {
397397
pred <- boost_tree_xgboost_data$classprob$post(pred, object)

R/boost_tree_data.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ boost_tree_xgboost_data <-
4141
newdata = quote(new_data)
4242
)
4343
),
44-
classes = list(
44+
class = list(
4545
pre = NULL,
4646
post = function(x, object) {
4747
if (is.vector(x)) {
@@ -97,7 +97,7 @@ boost_tree_C5.0_data <-
9797
func = c(pkg = "parsnip", fun = "C5.0_train"),
9898
defaults = list()
9999
),
100-
classes = list(
100+
class = list(
101101
pre = NULL,
102102
post = NULL,
103103
func = c(fun = "predict"),
@@ -152,7 +152,7 @@ boost_tree_spark_data <-
152152
dataset = quote(new_data)
153153
)
154154
),
155-
classes = list(
155+
class = list(
156156
pre = NULL,
157157
post = format_spark_class,
158158
func = c(pkg = "sparklyr", fun = "ml_predict"),

R/logistic_reg_data.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ logistic_reg_glm_data <-
3333
family = expr(stats::binomial)
3434
)
3535
),
36-
classes = list(
36+
class = list(
3737
pre = NULL,
3838
post = prob_to_class_2,
3939
func = c(fun = "predict"),
@@ -109,7 +109,7 @@ logistic_reg_glmnet_data <-
109109
family = "binomial"
110110
)
111111
),
112-
classes = list(
112+
class = list(
113113
pre = NULL,
114114
post = organize_glmnet_class,
115115
func = c(fun = "predict"),
@@ -156,7 +156,7 @@ logistic_reg_stan_data <-
156156
family = expr(stats::binomial)
157157
)
158158
),
159-
classes = list(
159+
class = list(
160160
pre = NULL,
161161
post = function(x, object) {
162162
x <- object$fit$family$linkinv(x)
@@ -268,7 +268,7 @@ logistic_reg_spark_data <-
268268
family = "binomial"
269269
)
270270
),
271-
classes = list(
271+
class = list(
272272
pre = NULL,
273273
post = format_spark_class,
274274
func = c(pkg = "sparklyr", fun = "ml_predict"),

R/mars_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ mars_earth_data <-
3434
type = "response"
3535
)
3636
),
37-
classes = list(
37+
class = list(
3838
pre = NULL,
3939
post = function(x, object) {
4040
x <- ifelse(x[,1] >= 0.5, object$lvl[2], object$lvl[1])

R/mlp_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ mlp_keras_data <-
3535
x = quote(as.matrix(new_data))
3636
)
3737
),
38-
classes = list(
38+
class = list(
3939
pre = NULL,
4040
post = function(x, object) {
4141
object$lvl[x + 1]
@@ -92,7 +92,7 @@ mlp_nnet_data <-
9292
type = "raw"
9393
)
9494
),
95-
classes = list(
95+
class = list(
9696
pre = NULL,
9797
post = NULL,
9898
func = c(fun = "predict"),

R/multinom_reg_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ multinom_reg_glmnet_data <-
2828
family = "multinomial"
2929
)
3030
),
31-
classes = list(
31+
class = list(
3232
pre = check_glmnet_lambda,
3333
post = organize_multnet_class,
3434
func = c(fun = "predict"),
@@ -75,7 +75,7 @@ multinom_reg_spark_data <-
7575
family = "multinomial"
7676
)
7777
),
78-
classes = list(
78+
class = list(
7979
pre = NULL,
8080
post = format_spark_class,
8181
func = c(pkg = "sparklyr", fun = "ml_predict"),

R/nearest_neighbor_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ nearest_neighbor_kknn_data <-
4242
type = "raw"
4343
)
4444
),
45-
classes = list(
45+
class = list(
4646
pre = function(x, object) {
4747
if (!(object$fit$response %in% c("ordinal", "nominal"))) {
4848
stop("`kknn` model does not appear to use class predictions. Was ",

R/predict_class.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,23 @@ predict_class.model_fit <- function (object, new_data, ...) {
1313
stop("`predict.model_fit` is for predicting factor outcomes.",
1414
call. = FALSE)
1515

16-
if (!any(names(object$spec$method) == "classes"))
16+
if (!any(names(object$spec$method) == "class"))
1717
stop("No class prediction module defined for this model.", call. = FALSE)
1818

1919
new_data <- prepare_data(object, new_data)
2020

2121
# preprocess data
22-
if (!is.null(object$spec$method$classes$pre))
23-
new_data <- object$spec$method$classes$pre(new_data, object)
22+
if (!is.null(object$spec$method$class$pre))
23+
new_data <- object$spec$method$class$pre(new_data, object)
2424

2525
# create prediction call
26-
pred_call <- make_pred_call(object$spec$method$classes)
26+
pred_call <- make_pred_call(object$spec$method$class)
2727

2828
res <- eval_tidy(pred_call)
2929

3030
# post-process the predictions
31-
if(!is.null(object$spec$method$classes$post)) {
32-
res <- object$spec$method$classes$post(res, object)
31+
if(!is.null(object$spec$method$class$post)) {
32+
res <- object$spec$method$class$post(res, object)
3333
}
3434

3535
# coerce levels to those in `object`

R/rand_forest_data.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ rand_forest_ranger_data <-
123123
verbose = FALSE
124124
)
125125
),
126-
classes = list(
126+
class = list(
127127
pre = NULL,
128128
post = ranger_class_pred,
129129
func = c(fun = "predict"),
@@ -200,7 +200,7 @@ rand_forest_randomForest_data <-
200200
newdata = quote(new_data)
201201
)
202202
),
203-
classes = list(
203+
class = list(
204204
pre = NULL,
205205
post = NULL,
206206
func = c(fun = "predict"),
@@ -257,7 +257,7 @@ rand_forest_spark_data <-
257257
dataset = quote(new_data)
258258
)
259259
),
260-
classes = list(
260+
class = list(
261261
pre = NULL,
262262
post = format_spark_class,
263263
func = c(pkg = "sparklyr", fun = "ml_predict"),

0 commit comments

Comments
 (0)