Skip to content

Commit e6d93c7

Browse files
authored
Merge branch 'master' into knn-multi-predict
2 parents f120919 + 7829acb commit e6d93c7

File tree

13 files changed

+253 / -94 lines changed

13 files changed

+253 / -94 lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,9 @@
22

33
S3method(fit,model_spec)
44
S3method(fit_xy,model_spec)
5+
S3method(has_multi_pred,default)
6+
S3method(has_multi_pred,model_fit)
7+
S3method(has_multi_pred,workflow)
58
S3method(multi_predict,"_C5.0")
69
S3method(multi_predict,"_earth")
710
S3method(multi_predict,"_elnet")
@@ -92,6 +95,7 @@ export(get_fit)
9295
export(get_from_env)
9396
export(get_model_env)
9497
export(get_pred_type)
98+
export(has_multi_pred)
9599
export(keras_mlp)
96100
export(linear_reg)
97101
export(logistic_reg)
@@ -211,4 +215,5 @@ importFrom(utils,capture.output)
211215
importFrom(utils,getFromNamespace)
212216
importFrom(utils,globalVariables)
213217
importFrom(utils,head)
218+
importFrom(utils,methods)
214219
importFrom(vctrs,vec_unique)

R/boost_tree.R

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ print.boost_tree <- function(x, ...) {
137137
cat("Boosted Tree Model Specification (", x$mode, ")\n\n", sep = "")
138138
model_printer(x, ...)
139139

140-
if(!is.null(x$method$fit$args)) {
140+
if (!is.null(x$method$fit$args)) {
141141
cat("Model fit template:\n")
142142
print(show_call(x))
143143
}
@@ -211,14 +211,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
211211
x <- translate.default(x, engine, ...)
212212

213213
if (engine == "spark") {
214-
if (x$mode == "unknown")
214+
if (x$mode == "unknown") {
215215
stop(
216216
"For spark boosted trees models, the mode cannot be 'unknown' ",
217217
"if the specification is to be translated.",
218218
call. = FALSE
219219
)
220-
else
220+
} else {
221221
x$method$fit$args$type <- x$mode
222+
}
222223
}
223224
x
224225
}
@@ -282,27 +283,33 @@ xgb_train <- function(
282283
}
283284
}
284285

285-
if (is.data.frame(x))
286+
if (is.data.frame(x)) {
286287
x <- as.matrix(x) # maybe use model.matrix here?
288+
}
287289

288290
n <- nrow(x)
289291
p <- ncol(x)
290292

291-
if (!inherits(x, "xgb.DMatrix"))
293+
if (!inherits(x, "xgb.DMatrix")) {
292294
x <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
293-
else
295+
} else {
294296
xgboost::setinfo(x, "label", y)
297+
}
295298

296299
# translate `subsample` and `colsample_bytree` to be on (0, 1] if not
297-
if(subsample > 1)
300+
if (subsample > 1) {
298301
subsample <- subsample/n
299-
if(subsample > 1)
302+
}
303+
if (subsample > 1) {
300304
subsample <- 1
305+
}
301306

302-
if(colsample_bytree > 1)
307+
if (colsample_bytree > 1) {
303308
colsample_bytree <- colsample_bytree/p
304-
if(colsample_bytree > 1)
309+
}
310+
if (colsample_bytree > 1) {
305311
colsample_bytree <- 1
312+
}
306313

307314
arg_list <- list(
308315
eta = eta,
@@ -321,18 +328,19 @@ xgb_train <- function(
321328
nrounds = nrounds,
322329
objective = loss
323330
)
324-
if (!is.null(num_class))
331+
if (!is.null(num_class)) {
325332
main_args$num_class <- num_class
333+
}
326334

327335
call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
328336

329337
# override or add some other args
330338
others <- list(...)
331339
others <-
332340
others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
333-
if (length(others) > 0)
334-
for (i in names(others))
335-
call[[i]] <- others[[i]]
341+
if (length(others) > 0) {
342+
call <- rlang::call_modify(call, !!!others)
343+
}
336344

337345
eval_tidy(call, env = current_env())
338346
}
@@ -348,7 +356,7 @@ xgb_pred <- function(object, newdata, ...) {
348356

349357
x = switch(
350358
object$params$objective,
351-
"reg:linear" =, "reg:logistic" =, "binary:logistic" = res,
359+
"reg:linear" = , "reg:logistic" = , "binary:logistic" = res,
352360
"binary:logitraw" = stats::binomial()$linkinv(res),
353361
"multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
354362
res
@@ -362,11 +370,13 @@ xgb_pred <- function(object, newdata, ...) {
362370
#' @param trees An integer vector for the number of trees in the ensemble.
363371
multi_predict._xgb.Booster <-
364372
function(object, new_data, type = NULL, trees = NULL, ...) {
365-
if (any(names(enquos(...)) == "newdata"))
373+
if (any(names(enquos(...)) == "newdata")) {
366374
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
375+
}
367376

368-
if (is.null(trees))
377+
if (is.null(trees)) {
369378
trees <- object$fit$nIter
379+
}
370380
trees <- sort(trees)
371381

372382
if (is.null(type)) {
@@ -376,9 +386,8 @@ multi_predict._xgb.Booster <-
376386
type <- "numeric"
377387
}
378388

379-
res <-
380-
map_df(trees, xgb_by_tree, object = object,
381-
new_data = new_data, type = type, ...)
389+
res <- map_df(trees, xgb_by_tree, object = object, new_data = new_data,
390+
type = type, ...)
382391
res <- arrange(res, .row, trees)
383392
res <- split(res[, -1], res$.row)
384393
names(res) <- NULL
@@ -451,20 +460,18 @@ C5.0_train <-
451460
ctrl <- call2("C5.0Control", .ns = "C50")
452461
ctrl$minCases <- minCases
453462
ctrl$sample <- sample
454-
for(i in names(ctrl_args))
455-
ctrl[[i]] <- ctrl_args[[i]]
463+
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
456464

457465
fit_call <- call2("C5.0", .ns = "C50")
458466
fit_call$x <- expr(x)
459467
fit_call$y <- expr(y)
460468
fit_call$trials <- trials
461469
fit_call$control <- ctrl
462-
if(!is.null(weights))
470+
if (!is.null(weights)) {
463471
fit_call$weights <- quote(weights)
464-
465-
for (i in names(fit_args)) {
466-
fit_call[[i]] <- fit_args[[i]]
467472
}
473+
fit_call <- rlang::call_modify(fit_call, !!!fit_args)
474+
468475
eval_tidy(fit_call)
469476
}
470477

R/decision_tree.R

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -187,8 +187,6 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
187187
"if the specification is to be translated.",
188188
call. = FALSE
189189
)
190-
} else {
191-
arg_vals$type <- x$mode
192190
}
193191

194192
# See "Details" in ?ml_random_forest_classifier. `feature_subset_strategy`
@@ -273,18 +271,16 @@ rpart_train <-
273271
ctrl$minsplit <- minsplit
274272
ctrl$maxdepth <- maxdepth
275273
ctrl$cp <- cp
276-
for(i in names(ctrl_args))
277-
ctrl[[i]] <- ctrl_args[[i]]
274+
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
278275

279276
fit_call <- call2("rpart", .ns = "rpart")
280277
fit_call$formula <- expr(formula)
281278
fit_call$data <- expr(data)
282279
fit_call$control <- ctrl
283-
if(!is.null(weights))
280+
if (!is.null(weights)) {
284281
fit_call$weights <- quote(weights)
285-
286-
for(i in names(fit_args))
287-
fit_call[[i]] <- fit_args[[i]]
282+
}
283+
fit_call <- rlang::call_modify(fit_call, !!!fit_args)
288284

289285
eval_tidy(fit_call)
290286
}

R/mlp.R

Lines changed: 52 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -275,27 +275,31 @@ keras_mlp <-
275275
seeds = sample.int(10^5, size = 3),
276276
...) {
277277

278-
if(decay > 0 & dropout > 0)
278+
if (decay > 0 & dropout > 0) {
279279
stop("Please use either dropoput or weight decay.", call. = FALSE)
280-
281-
if (!is.matrix(x))
280+
}
281+
if (!is.matrix(x)) {
282282
x <- as.matrix(x)
283+
}
283284

284-
if(is.character(y))
285+
if (is.character(y)) {
285286
y <- as.factor(y)
287+
}
286288
factor_y <- is.factor(y)
287289

288-
if (factor_y)
290+
if (factor_y) {
289291
y <- class2ind(y)
290-
else {
291-
if (isTRUE(ncol(y) > 1))
292+
} else {
293+
if (isTRUE(ncol(y) > 1)) {
292294
y <- as.matrix(y)
293-
else
295+
} else {
294296
y <- matrix(y, ncol = 1)
297+
}
295298
}
296299

297300
model <- keras::keras_model_sequential()
298-
if(decay > 0) {
301+
302+
if (decay > 0) {
299303
model %>%
300304
keras::layer_dense(
301305
units = hidden_units,
@@ -313,53 +317,57 @@ keras_mlp <-
313317
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
314318
)
315319
}
316-
if(dropout > 0)
320+
321+
if (dropout > 0) {
317322
model %>%
318-
keras::layer_dense(
319-
units = hidden_units,
320-
activation = act,
321-
input_shape = ncol(x),
322-
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
323-
) %>%
324-
keras::layer_dropout(rate = dropout, seed = seeds[2])
325-
326-
if (factor_y)
323+
keras::layer_dense(
324+
units = hidden_units,
325+
activation = act,
326+
input_shape = ncol(x),
327+
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
328+
) %>%
329+
keras::layer_dropout(rate = dropout, seed = seeds[2])
330+
}
331+
332+
if (factor_y) {
327333
model <- model %>%
328-
keras::layer_dense(
329-
units = ncol(y),
330-
activation = 'softmax',
331-
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
332-
)
333-
else
334+
keras::layer_dense(
335+
units = ncol(y),
336+
activation = 'softmax',
337+
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
338+
)
339+
} else {
334340
model <- model %>%
335-
keras::layer_dense(
336-
units = ncol(y),
337-
activation = 'linear',
338-
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
339-
)
341+
keras::layer_dense(
342+
units = ncol(y),
343+
activation = 'linear',
344+
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
345+
)
346+
}
340347

341348
arg_values <- parse_keras_args(...)
342-
compile_call <- expr(
343-
keras::compile(object = model)
344-
)
345-
if(!any(names(arg_values$compile) == "loss"))
346-
compile_call$loss <-
347-
if(factor_y) "binary_crossentropy" else "mse"
348-
if(!any(names(arg_values$compile) == "optimizer"))
349+
compile_call <- expr(keras::compile(object = model))
350+
if (!any(names(arg_values$compile) == "loss")) {
351+
if (factor_y) {
352+
compile_call$loss <- "binary_crossentropy"
353+
} else {
354+
compile_call$loss <- "mse"
355+
}
356+
}
357+
358+
if (!any(names(arg_values$compile) == "optimizer")) {
349359
compile_call$optimizer <- "adam"
350-
for(arg in names(arg_values$compile))
351-
compile_call[[arg]] <- arg_values$compile[[arg]]
360+
}
361+
362+
compile_call <- rlang::call_modify(compile_call, !!!arg_values$compile)
352363

353364
model <- eval_tidy(compile_call)
354365

355-
fit_call <- expr(
356-
keras::fit(object = model)
357-
)
366+
fit_call <- expr(keras::fit(object = model))
358367
fit_call$x <- quote(x)
359368
fit_call$y <- quote(y)
360369
fit_call$epochs <- epochs
361-
for(arg in names(arg_values$fit))
362-
fit_call[[arg]] <- arg_values$fit[[arg]]
370+
fit_call <- rlang::call_modify(fit_call, !!!arg_values$fit)
363371

364372
history <- eval_tidy(fit_call)
365373
model

0 commit comments

Comments
 (0)