Skip to content

Commit f252377

Browse files
authored
Merge branch 'master' into modelspec_predict
2 parents 9b44807 + 73f001e commit f252377

File tree

14 files changed

+912
-235
lines changed

14 files changed

+912
-235
lines changed

NAMESPACE

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ S3method(multi_predict,"_lognet")
99
S3method(multi_predict,"_multnet")
1010
S3method(multi_predict,"_xgb.Booster")
1111
S3method(multi_predict,default)
12+
S3method(nullmodel,default)
1213
S3method(predict,"_elnet")
1314
S3method(predict,"_lognet")
1415
S3method(predict,"_multnet")
1516
S3method(predict,model_fit)
1617
S3method(predict,model_spec)
18+
S3method(predict,nullmodel)
1719
S3method(predict_class,"_lognet")
1820
S3method(predict_class,model_fit)
1921
S3method(predict_classprob,"_lognet")
@@ -38,6 +40,7 @@ S3method(print,model_fit)
3840
S3method(print,model_spec)
3941
S3method(print,multinom_reg)
4042
S3method(print,nearest_neighbor)
43+
S3method(print,nullmodel)
4144
S3method(print,rand_forest)
4245
S3method(print,surv_reg)
4346
S3method(print,svm_poly)
@@ -96,6 +99,8 @@ export(model_printer)
9699
export(multi_predict)
97100
export(multinom_reg)
98101
export(nearest_neighbor)
102+
export(null_model)
103+
export(nullmodel)
99104
export(predict.model_fit)
100105
export(predict_class)
101106
export(predict_class.model_fit)
@@ -123,9 +128,6 @@ export(svm_rbf)
123128
export(translate)
124129
export(varying)
125130
export(varying_args)
126-
export(varying_args.model_spec)
127-
export(varying_args.recipe)
128-
export(varying_args.step)
129131
export(xgb_train)
130132
import(rlang)
131133
importFrom(dplyr,arrange)
@@ -147,6 +149,7 @@ importFrom(dplyr,tally)
147149
importFrom(dplyr,vars)
148150
importFrom(generics,fit)
149151
importFrom(generics,fit_xy)
152+
importFrom(generics,varying_args)
150153
importFrom(glue,glue_collapse)
151154
importFrom(magrittr,"%>%")
152155
importFrom(purrr,as_vector)

NEWS.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
# parsnip 0.0.1.9000
22

3-
## Bug fixes
3+
## Other Changes
4+
5+
* `varying_args()` now has a `full` argument to control whether the full set
6+
of possible varying arguments is returned (as opposed to only the arguments
7+
that are actually varying).
8+
9+
## Bug Fixes
10+
11+
* `varying_args()` now uses the version from the `generics` package. This means
12+
that the first argument, `x`, has been renamed to `object` to align with
13+
generics.
14+
15+
* For the recipes step method of `varying_args()`, there is now error checking
16+
to catch if a user tries to specify an argument that _cannot_ be varying as
17+
varying (for example, the `id`) (#132).
18+
19+
* `find_varying()`, the internal function for detecting varying arguments,
20+
now returns correct results when a size 0 argument is provided. It can also now
21+
detect varying arguments nested deeply into a call (#131, #134).
422

523
* For multinomial regression, the `.pred_` prefix is now only added to prediction
624
column names once (#107).

R/nullmodel.R

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
#' Fit a simple, non-informative model
2+
#'
3+
#' Fit a single mean or largest class model
4+
#'
5+
#' \code{nullmodel} emulates other model building functions, but returns the
6+
#' simplest model possible given a training set: a single mean for numeric
7+
#' outcomes and the most prevalent class for factor outcomes. When class
8+
#' probabilities are requested, the percentage of the training set samples with
9+
#' the most prevalent class is returned.
10+
#'
11+
#' @aliases nullmodel nullmodel.default predict.nullmodel
12+
#' @param x An optional matrix or data frame of predictors. These values are
13+
#' not used in the model fit
14+
#' @param y A numeric vector (for regression) or factor (for classification) of
15+
#' outcomes
16+
#' @param \dots Optional arguments (not yet used)
17+
#' @param object An object of class \code{nullmodel}
18+
#' @param new_data A matrix or data frame of predictors (only used to determine
19+
#' the number of predictions to return)
20+
#' @param type Either "raw" (for regression), "class" or "prob" (for
21+
#' classification)
22+
#' @return The output of \code{nullmodel} is a list of class \code{nullmodel}
23+
#' with elements \item{call }{the function call} \item{value }{the mean of
24+
#' \code{y} or the most prevalent class} \item{levels }{when \code{y} is a
25+
#' factor, a vector of levels. \code{NULL} otherwise} \item{pct }{when \code{y}
26+
#' is a factor, a data frame with a column for each class (\code{NULL}
27+
#' otherwise). The column for the most prevalent class has the proportion of
28+
#' the training samples with that class (the other columns are zero). } \item{n
29+
#' }{the number of elements in \code{y}}
30+
#'
31+
#' \code{predict.nullmodel} returns a either a factor or numeric vector
32+
#' depending on the class of \code{y}. All predictions are always the same.
33+
#' @keywords models
34+
#' @examples
35+
#'
36+
#' outcome <- factor(sample(letters[1:2],
37+
#' size = 100,
38+
#' prob = c(.1, .9),
39+
#' replace = TRUE))
40+
#' useless <- nullmodel(y = outcome)
41+
#' useless
42+
#' predict(useless, matrix(NA, nrow = 5))
43+
#'
44+
#' @export
45+
nullmodel <- function (x, ...) UseMethod("nullmodel")
46+
47+
#' @export
48+
#' @rdname nullmodel
49+
nullmodel.default <- function(x = NULL, y, ...) {
50+
51+
52+
if(is.factor(y)) {
53+
lvls <- levels(y)
54+
tab <- table(y)
55+
value <- names(tab)[which.max(tab)]
56+
pct <- tab/sum(tab)
57+
} else {
58+
lvls <- NULL
59+
pct <- NULL
60+
if(is.null(dim(y))) {
61+
value <- mean(y, na.rm = TRUE)
62+
} else {
63+
value <- colMeans(y, na.rm = TRUE)
64+
}
65+
}
66+
67+
structure(
68+
list(call = match.call(),
69+
value = value,
70+
levels = lvls,
71+
pct = pct,
72+
n = length(y[[1]])),
73+
class = "nullmodel")
74+
}
75+
76+
#' @export
77+
#' @rdname nullmodel
78+
print.nullmodel <- function(x, ...) {
79+
cat("Null",
80+
ifelse(is.null(x$levels), "Classification", "Regression"),
81+
"Model\n")
82+
x$call
83+
84+
if (length(x$value) == 1) {
85+
cat("Predicted Value:",
86+
ifelse(is.null(x$levels), format(x$value), x$value),
87+
"\n")
88+
} else {
89+
cat("Predicted Value:\n",
90+
names(x$value), "\n",
91+
x$value,
92+
"\n")
93+
}
94+
}
95+
96+
#' @export
97+
#' @rdname nullmodel
98+
predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) {
99+
if(is.null(type)) {
100+
type <- if(is.null(object$levels)) "raw" else "class"
101+
}
102+
103+
n <- if(is.null(new_data)) object$n else nrow(new_data)
104+
if(!is.null(object$levels)) {
105+
if(type == "prob") {
106+
out <- matrix(rep(object$pct, n), nrow = n, byrow = TRUE)
107+
colnames(out) <- object$levels
108+
out <- as.data.frame(out)
109+
} else {
110+
out <- factor(rep(object$value, n), levels = object$levels)
111+
}
112+
} else {
113+
if(type %in% c("prob", "class")) stop("ony raw predicitons are applicable to regression models")
114+
if(length(object$value) == 1) {
115+
out <- rep(object$value, n)
116+
} else {
117+
out <- as_tibble(matrix(rep(object$value, n),
118+
ncol = length(object$value), byrow = TRUE))
119+
120+
names(out) <- names(object$value)
121+
}
122+
}
123+
out
124+
}
125+
126+
#' General Interface for null models
127+
#'
128+
#' `null_model` is a way to generate a _specification_ of a model before
129+
#' fitting and allows the model to be created using R. It doens't have any
130+
#' main arguments.
131+
#'
132+
#' @param mode A single character string for the type of model.
133+
#' Possible values for this model are "unknown", "regression", or
134+
#' "classification".
135+
#' @details The model can be created using the `fit()` function using the
136+
#' following _engines_:
137+
#' \itemize{
138+
#' \item \pkg{R}: `"parsnip"`
139+
#' }
140+
#'
141+
#' @section Engine Details:
142+
#'
143+
#' Engines may have pre-set default arguments when executing the
144+
#' model fit call. For this type of
145+
#' model, the template of the fit calls are:
146+
#'
147+
#' \pkg{parsnip} classification
148+
#'
149+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "classification"), "parsnip")}
150+
#'
151+
#' \pkg{parsnip} regression
152+
#'
153+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::null_model(mode = "regression"), "parsnip")}
154+
#'
155+
#' @importFrom purrr map_lgl
156+
#' @seealso [varying()], [fit()]
157+
#' @examples
158+
#' null_model(mode = "regression")
159+
#' @export
160+
null_model <-
161+
function(mode = "classification") {
162+
# Check for correct mode
163+
if (!(mode %in% null_model_modes))
164+
stop("`mode` should be one of: ",
165+
paste0("'", null_model_modes, "'", collapse = ", "),
166+
call. = FALSE)
167+
168+
# Capture the arguments in quosures
169+
args <- list()
170+
171+
# Save some empty slots for future parts of the specification
172+
out <- list(args = args, eng_args = NULL,
173+
mode = mode, method = NULL, engine = NULL)
174+
175+
# set classes in the correct order
176+
class(out) <- make_classes("null_model")
177+
out
178+
}

R/nullmodel_data.R

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
null_model_arg_key <- data.frame(
2+
parsnip = NULL,
3+
row.names = NULL,
4+
stringsAsFactors = FALSE
5+
)
6+
7+
null_model_modes <- c("classification", "regression", "unknown")
8+
9+
null_model_engines <- data.frame(
10+
parsnip = c(TRUE, TRUE, FALSE),
11+
row.names = c("classification", "regression", "unknown")
12+
)
13+
14+
# ------------------------------------------------------------------------------
15+
16+
null_model_parsnip_data <-
17+
list(
18+
libs = "parsnip",
19+
fit = list(
20+
interface = "matrix",
21+
protect = c("x", "y"),
22+
func = c(fun = "nullmodel"),
23+
defaults = list()
24+
),
25+
class = list(
26+
pre = NULL,
27+
post = NULL,
28+
func = c(fun = "predict"),
29+
args =
30+
list(
31+
object = quote(object$fit),
32+
new_data = quote(new_data),
33+
type = "class"
34+
)
35+
),
36+
classprob = list(
37+
pre = NULL,
38+
post = function(x, object) {
39+
str(as_tibble(x))
40+
as_tibble(x)
41+
},
42+
func = c(fun = "predict"),
43+
args =
44+
list(
45+
object = quote(object$fit),
46+
new_data = quote(new_data),
47+
type = "prob"
48+
)
49+
),
50+
numeric = list(
51+
pre = NULL,
52+
post = NULL,
53+
func = c(fun = "predict"),
54+
args =
55+
list(
56+
object = quote(object$fit),
57+
new_data = quote(new_data),
58+
type = "numeric"
59+
)
60+
),
61+
raw = list(
62+
pre = NULL,
63+
post = NULL,
64+
func = c(fun = "predict"),
65+
args =
66+
list(
67+
object = quote(object$fit),
68+
new_data = quote(new_data),
69+
type = "raw"
70+
)
71+
)
72+
)

0 commit comments

Comments
 (0)