Skip to content

Commit bbce70f

Browse files
committed
linting and use of call_modify()
1 parent b9ac7e6 commit bbce70f

File tree

3 files changed

+89
-75
lines changed

3 files changed

+89
-75
lines changed

R/boost_tree.R

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ print.boost_tree <- function(x, ...) {
137137
cat("Boosted Tree Model Specification (", x$mode, ")\n\n", sep = "")
138138
model_printer(x, ...)
139139

140-
if(!is.null(x$method$fit$args)) {
140+
if (!is.null(x$method$fit$args)) {
141141
cat("Model fit template:\n")
142142
print(show_call(x))
143143
}
@@ -211,14 +211,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
211211
x <- translate.default(x, engine, ...)
212212

213213
if (engine == "spark") {
214-
if (x$mode == "unknown")
214+
if (x$mode == "unknown") {
215215
stop(
216216
"For spark boosted trees models, the mode cannot be 'unknown' ",
217217
"if the specification is to be translated.",
218218
call. = FALSE
219219
)
220-
else
220+
} else {
221221
x$method$fit$args$type <- x$mode
222+
}
222223
}
223224
x
224225
}
@@ -282,27 +283,33 @@ xgb_train <- function(
282283
}
283284
}
284285

285-
if (is.data.frame(x))
286+
if (is.data.frame(x)) {
286287
x <- as.matrix(x) # maybe use model.matrix here?
288+
}
287289

288290
n <- nrow(x)
289291
p <- ncol(x)
290292

291-
if (!inherits(x, "xgb.DMatrix"))
293+
if (!inherits(x, "xgb.DMatrix")) {
292294
x <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
293-
else
295+
} else {
294296
xgboost::setinfo(x, "label", y)
297+
}
295298

296299
# translate `subsample` and `colsample_bytree` to be on (0, 1] if not
297-
if(subsample > 1)
300+
if (subsample > 1) {
298301
subsample <- subsample/n
299-
if(subsample > 1)
302+
}
303+
if (subsample > 1) {
300304
subsample <- 1
305+
}
301306

302-
if(colsample_bytree > 1)
307+
if (colsample_bytree > 1) {
303308
colsample_bytree <- colsample_bytree/p
304-
if(colsample_bytree > 1)
309+
}
310+
if (colsample_bytree > 1) {
305311
colsample_bytree <- 1
312+
}
306313

307314
arg_list <- list(
308315
eta = eta,
@@ -321,18 +328,19 @@ xgb_train <- function(
321328
nrounds = nrounds,
322329
objective = loss
323330
)
324-
if (!is.null(num_class))
331+
if (!is.null(num_class)) {
325332
main_args$num_class <- num_class
333+
}
326334

327335
call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
328336

329337
# override or add some other args
330338
others <- list(...)
331339
others <-
332340
others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
333-
if (length(others) > 0)
334-
for (i in names(others))
335-
call[[i]] <- others[[i]]
341+
if (length(others) > 0) {
342+
call <- rlang::call_modify(call, !!!others)
343+
}
336344

337345
eval_tidy(call, env = current_env())
338346
}
@@ -348,7 +356,7 @@ xgb_pred <- function(object, newdata, ...) {
348356

349357
x = switch(
350358
object$params$objective,
351-
"reg:linear" =, "reg:logistic" =, "binary:logistic" = res,
359+
"reg:linear" = , "reg:logistic" = , "binary:logistic" = res,
352360
"binary:logitraw" = stats::binomial()$linkinv(res),
353361
"multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
354362
res
@@ -360,11 +368,13 @@ xgb_pred <- function(object, newdata, ...) {
360368
#' @export
361369
multi_predict._xgb.Booster <-
362370
function(object, new_data, type = NULL, trees = NULL, ...) {
363-
if (any(names(enquos(...)) == "newdata"))
371+
if (any(names(enquos(...)) == "newdata")) {
364372
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
373+
}
365374

366-
if (is.null(trees))
375+
if (is.null(trees)) {
367376
trees <- object$fit$nIter
377+
}
368378
trees <- sort(trees)
369379

370380
if (is.null(type)) {
@@ -374,9 +384,8 @@ multi_predict._xgb.Booster <-
374384
type <- "numeric"
375385
}
376386

377-
res <-
378-
map_df(trees, xgb_by_tree, object = object,
379-
new_data = new_data, type = type, ...)
387+
res <- map_df(trees, xgb_by_tree, object = object, new_data = new_data,
388+
type = type, ...)
380389
res <- arrange(res, .row, trees)
381390
res <- split(res[, -1], res$.row)
382391
names(res) <- NULL
@@ -449,19 +458,18 @@ C5.0_train <-
449458
ctrl <- call2("C5.0Control", .ns = "C50")
450459
ctrl$minCases <- minCases
451460
ctrl$sample <- sample
452-
for(i in names(ctrl_args))
453-
ctrl[[i]] <- ctrl_args[[i]]
461+
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
454462

455463
fit_call <- call2("C5.0", .ns = "C50")
456464
fit_call$x <- expr(x)
457465
fit_call$y <- expr(y)
458466
fit_call$trials <- trials
459467
fit_call$control <- ctrl
460-
if(!is.null(weights))
468+
if (!is.null(weights)) {
461469
fit_call$weights <- quote(weights)
470+
}
471+
fit_call <- rlang::call_modify(fit_call, !!!fit_args)
462472

463-
for(i in names(fit_args))
464-
fit_call[[i]] <- fit_args[[i]]
465473
eval_tidy(fit_call)
466474
}
467475

R/decision_tree.R

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -273,18 +273,16 @@ rpart_train <-
273273
ctrl$minsplit <- minsplit
274274
ctrl$maxdepth <- maxdepth
275275
ctrl$cp <- cp
276-
for(i in names(ctrl_args))
277-
ctrl[[i]] <- ctrl_args[[i]]
276+
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
278277

279278
fit_call <- call2("rpart", .ns = "rpart")
280279
fit_call$formula <- expr(formula)
281280
fit_call$data <- expr(data)
282281
fit_call$control <- ctrl
283-
if(!is.null(weights))
282+
if (!is.null(weights)) {
284283
fit_call$weights <- quote(weights)
285-
286-
for(i in names(fit_args))
287-
fit_call[[i]] <- fit_args[[i]]
284+
}
285+
fit_call <- rlang::call_modify(fit_call, !!!fit_args)
288286

289287
eval_tidy(fit_call)
290288
}

R/mlp.R

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -275,27 +275,31 @@ keras_mlp <-
275275
seeds = sample.int(10^5, size = 3),
276276
...) {
277277

278-
if(decay > 0 & dropout > 0)
278+
if (decay > 0 & dropout > 0) {
279279
stop("Please use either dropoput or weight decay.", call. = FALSE)
280-
281-
if (!is.matrix(x))
280+
}
281+
if (!is.matrix(x)) {
282282
x <- as.matrix(x)
283+
}
283284

284-
if(is.character(y))
285+
if (is.character(y)) {
285286
y <- as.factor(y)
287+
}
286288
factor_y <- is.factor(y)
287289

288-
if (factor_y)
290+
if (factor_y) {
289291
y <- class2ind(y)
290-
else {
291-
if (isTRUE(ncol(y) > 1))
292+
} else {
293+
if (isTRUE(ncol(y) > 1)) {
292294
y <- as.matrix(y)
293-
else
295+
} else {
294296
y <- matrix(y, ncol = 1)
297+
}
295298
}
296299

297300
model <- keras::keras_model_sequential()
298-
if(decay > 0) {
301+
302+
if (decay > 0) {
299303
model %>%
300304
keras::layer_dense(
301305
units = hidden_units,
@@ -313,53 +317,57 @@ keras_mlp <-
313317
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
314318
)
315319
}
316-
if(dropout > 0)
320+
321+
if (dropout > 0) {
317322
model %>%
318-
keras::layer_dense(
319-
units = hidden_units,
320-
activation = act,
321-
input_shape = ncol(x),
322-
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
323-
) %>%
324-
keras::layer_dropout(rate = dropout, seed = seeds[2])
325-
326-
if (factor_y)
323+
keras::layer_dense(
324+
units = hidden_units,
325+
activation = act,
326+
input_shape = ncol(x),
327+
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[1])
328+
) %>%
329+
keras::layer_dropout(rate = dropout, seed = seeds[2])
330+
}
331+
332+
if (factor_y) {
327333
model <- model %>%
328-
keras::layer_dense(
329-
units = ncol(y),
330-
activation = 'softmax',
331-
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
332-
)
333-
else
334+
keras::layer_dense(
335+
units = ncol(y),
336+
activation = 'softmax',
337+
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
338+
)
339+
} else {
334340
model <- model %>%
335-
keras::layer_dense(
336-
units = ncol(y),
337-
activation = 'linear',
338-
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
339-
)
341+
keras::layer_dense(
342+
units = ncol(y),
343+
activation = 'linear',
344+
kernel_initializer = keras::initializer_glorot_uniform(seed = seeds[3])
345+
)
346+
}
340347

341348
arg_values <- parse_keras_args(...)
342-
compile_call <- expr(
343-
keras::compile(object = model)
344-
)
345-
if(!any(names(arg_values$compile) == "loss"))
346-
compile_call$loss <-
347-
if(factor_y) "binary_crossentropy" else "mse"
348-
if(!any(names(arg_values$compile) == "optimizer"))
349+
compile_call <- expr(keras::compile(object = model))
350+
if (!any(names(arg_values$compile) == "loss")) {
351+
if (factor_y) {
352+
compile_call$loss <- "binary_crossentropy"
353+
} else {
354+
compile_call$loss <- "mse"
355+
}
356+
}
357+
358+
if (!any(names(arg_values$compile) == "optimizer")) {
349359
compile_call$optimizer <- "adam"
350-
for(arg in names(arg_values$compile))
351-
compile_call[[arg]] <- arg_values$compile[[arg]]
360+
}
361+
362+
compile_call <- rlang::call_modify(compile_call, !!!arg_values$compile)
352363

353364
model <- eval_tidy(compile_call)
354365

355-
fit_call <- expr(
356-
keras::fit(object = model)
357-
)
366+
fit_call <- expr(keras::fit(object = model))
358367
fit_call$x <- quote(x)
359368
fit_call$y <- quote(y)
360369
fit_call$epochs <- epochs
361-
for(arg in names(arg_values$fit))
362-
fit_call[[arg]] <- arg_values$fit[[arg]]
370+
fit_call <- rlang::call_modify(fit_call, !!!arg_values$fit)
363371

364372
history <- eval_tidy(fit_call)
365373
model

0 commit comments

Comments
 (0)