Skip to content

Commit 31d5149

Browse files
committed
changed model object naming convention
1 parent 22192e6 commit 31d5149

File tree

4 files changed

+21
-15
lines changed

4 files changed

+21
-15
lines changed

R/engines.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@ get_model_objects <- function(x, engine) {
33
if(x$mode == "unknown")
44
stop("Please specify a mode for the model (e.g. regression, classification, etc.) ",
55
"so that the model code can be finalized", call. = FALSE)
6-
nm <- paste("get", engine, x$mode, sep = "_")
6+
cls <- class(x)
7+
cls <- cls[cls != "model_spec"]
8+
# This is a short term hack to get most general class
9+
# Q: do we need mode-specific classes?
10+
cls <- cls[which.min(nchar(cls))]
11+
nm <- paste(cls, engine, x$mode, sep = "_")
712
res <- try(get(nm), silent = TRUE)
813
if(inherits(res, "try-error"))
914
stop("Can't find model object ", nm)

R/fit.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# - try/catch all model fit evaluations
44
# - option to capture output/verboseness
55
# - devise a unit test plan that does not add pkg deps for each model
6+
# - where/how to add data checks (e.g. factors for classification)
67

78

89
#' Fit a Model Specification to a Dataset

R/rand_forest.R

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
#' @examples
4141
#' rand_forest(mode = "classification", trees = 2000)
4242
#'
43-
#' # Parameters can be reprresented by a placeholder:
43+
#' # Parameters can be represented by a placeholder:
4444
#' rand_forest(mode = "regression", mtry = varying())
4545

4646
rand_forest <- function (mode, ...)
@@ -52,8 +52,8 @@ rand_forest <- function (mode, ...)
5252
#' Possible values for this model are "unknown", "regression", or
5353
#' "classification".
5454
#' @param engine_args A named list of arguments to be used by the
55-
#' underlying models (e.g., [ranger::ranger()],
56-
#' [randomForest::randomForest()], etc.). These are not evaluated
55+
#' underlying models (e.g., `ranger::ranger`,
56+
#' `randomForest::randomForest`, etc.). These are not evaluated
5757
#' until the model is fit and will be substituted into the model
5858
#' fit expression.
5959
#' @param mtry An integer for the number of predictors that will
@@ -130,7 +130,7 @@ print.rand_forest <- function(x, ...) {
130130
## in the ellipses), when should the ellipses be removed? Maybe right
131131
## before evaluation since `update` might be invoked to change those.
132132

133-
get_ranger_regression <- function () {
133+
rand_forest_ranger_regression <- function () {
134134
libs <- "ranger"
135135
interface <- "formula"
136136
protect = c("ranger", "formula", "data", "case.weights")
@@ -166,11 +166,11 @@ get_ranger_regression <- function () {
166166
status.variable.name = NULL,
167167
classification = NULL
168168
)
169-
)
169+
)
170170
list(library = libs, interface = interface, fit = fit, protect = protect)
171171
}
172172

173-
get_randomForest_regression <- function () {
173+
rand_forest_randomForest_regression <- function () {
174174
libs <- "randomForest"
175175
interface <- "data.frame"
176176
protect = c("randomForest", "x", "y")
@@ -205,7 +205,7 @@ get_randomForest_regression <- function () {
205205
list(library = libs, interface = interface, fit = fit, protect = protect)
206206
}
207207

208-
get_spark_regression <- function () {
208+
rand_forest_spark_regression <- function () {
209209
libs <- "sparklyr"
210210
interface <- "data.frame" # adjust this to something else
211211
protect = c("x", "formula", "label_col", "features_col")
@@ -236,7 +236,7 @@ get_spark_regression <- function () {
236236
list(library = libs, interface = interface, fit = fit, protect = protect)
237237
}
238238

239-
get_ranger_classification <- function () {
239+
rand_forest_ranger_classification <- function () {
240240
libs <- "ranger"
241241
interface <- "formula"
242242
protect = c("ranger", "formula", "data", "case.weights")
@@ -276,7 +276,7 @@ get_ranger_classification <- function () {
276276
list(library = libs, interface = interface, fit = fit, protect = protect)
277277
}
278278

279-
get_randomForest_classification <- function () {
279+
rand_forest_randomForest_classification <- function () {
280280
libs <- "randomForest"
281281
interface <- "data.frame"
282282
protect = c("randomForest", "x", "y")
@@ -311,7 +311,7 @@ get_randomForest_classification <- function () {
311311
list(library = libs, interface = interface, fit = fit, protect = protect)
312312
}
313313

314-
get_spark_regression <- function () {
314+
rand_forest_spark_regression <- function () {
315315
libs <- "sparklyr"
316316
interface <- "data.frame" # adjust this to something else
317317
protect = c("x", "formula", "label_col", "features_col")
@@ -396,7 +396,7 @@ finalize.rand_forest <- function(x, engine = NULL, ...) {
396396
if (length(x$others) > 0)
397397
x$method$fit <- sub_arg_values(x$method$fit, x$others, ignore = x$method$protect)
398398

399-
# remove NULL and unmodified argiment values
399+
# remove NULL and unmodified argument values
400400
modifed_args <- names(real_args)[!vapply(real_args, null_value, lgl(1))]
401401
x$method$fit <- prune_expr(x$method$fit, x$method$protect, c(modifed_args, names(x$others)))
402402
x

man/rand_forest.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)