Skip to content

Commit 69ee298

Browse files
committed
logistic regression draft
1 parent 46197c8 commit 69ee298

File tree

3 files changed

+428
-0
lines changed

3 files changed

+428
-0
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(finalize,logistic_reg)
34
S3method(finalize,rand_forest)
45
S3method(fit,model_spec)
6+
S3method(logistic_reg,default)
7+
S3method(print,logistic_reg)
58
S3method(print,rand_forest)
69
S3method(rand_forest,default)
10+
S3method(update,logistic_reg)
711
S3method(update,rand_forest)
812
export(finalize)
913
export(fit)
14+
export(logistic_reg)
1015
export(rand_forest)
1116
export(varying)
1217
importFrom(purrr,map_lgl)

R/logistic_reg.R

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
#' General Interface for Logistic Regression Models
2+
#'
3+
#' `logistic_reg` is a way to generate a _specification_ of a model
4+
#' before fitting and allows the model to be created using
5+
#' different packages in R, Stan, or via Spark. The main arguments for the
6+
#' model are:
7+
#' \itemize{
8+
#' \item \code{link}: The link function.
9+
#' \item \code{regularization}: The total amount of regularization
10+
#' in the model. Note that this must be zero for some engines.
11+
#' \item \code{mixture}: The proportion of L2 regularization in
12+
#' the model. Note that this will be ignored for some engines.
13+
#' }
14+
#' These arguments are converted to their specific names at the
15+
#' time that the model is fit. Other options and argument can be
16+
#' set using the `engine_args` argument. If left to their defaults
17+
#' here (`NULL`), the values are taken from the underlying model
18+
#' functions.
19+
#'
20+
#' The data given to the function are not saved and are only used
21+
#' to determine the _mode_ of the model. For `logistic_reg`,the
22+
#' mode will always be "classification".
23+
#'
24+
#' The model can be created using the [fit()] function using the
25+
#' following _engines_:
26+
#' \itemize{
27+
#' \item \pkg{R}: `"glm"` or `"glmnet"`
28+
#' \item \pkg{Stan}: `"rstanarm"`
29+
#' \item \pkg{Spark}: `"spark"`
30+
#' }
31+
#' @export
32+
#' @rdname logistic_reg
33+
#' @importFrom rlang expr enquo missing_arg
34+
#' @importFrom purrr map_lgl
35+
#' @seealso [varying()], [fit()]
36+
#' @examples
37+
#' logistic_reg()
38+
#'
39+
#' # Parameters can be represented by a placeholder:
40+
#' logistic_reg(link = "probit", regularization = varying())
41+
42+
logistic_reg <- function (mode, ...)
43+
UseMethod("logistic_reg")
44+
45+
#' @rdname logistic_reg
46+
#' @export
47+
#' @param mode A single character string for the type of model.
48+
#' The only possible value for this model is "classification".
49+
#' @param engine_args A named list of arguments to be used by the
50+
#' underlying models (e.g., `stats::glm`,
51+
#' `rstanarm::stan_glm`, etc.). These are not evaluated
52+
#' until the model is fit and will be substituted into the model
53+
#' fit expression.
54+
#' @param link A character string for the link function. Possible
55+
#' values are "logit", "probit", "cauchit", "log" and "cloglog".
56+
#' @param regularization An non-negative number representing the
57+
#' total amount of regularization.
58+
#' @param mixture A number between zero and one (inclusive) that
59+
#' represents the proportion of regularization that is used for the
60+
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
61+
#' (the lasso).
62+
#' @param ... Used for S3 method consistency. Any arguments passed to
63+
#' the ellipses will result in an error. Use `engine_args` instead.
64+
65+
66+
logistic_reg.default <-
67+
function(mode = "classification",
68+
link = NULL,
69+
regularization = NULL,
70+
mixture = NULL,
71+
engine_args = list(),
72+
...) {
73+
check_empty_ellipse(...)
74+
if (!(mode %in% logistic_reg_modes))
75+
stop(
76+
"`mode` should be one of: ",
77+
paste0("'", logistic_reg_modes, "'", collapse = ", "),
78+
call. = FALSE
79+
)
80+
81+
args <- list(
82+
link = rlang::enquo(link),
83+
regularization = rlang::enquo(regularization),
84+
mixture = rlang::enquo(mixture)
85+
)
86+
87+
others <- parse_engine_options(rlang::enquo(engine_args))
88+
89+
# write a constructor function
90+
out <- list(
91+
args = args,
92+
others = others,
93+
mode = mode,
94+
method = NULL,
95+
engine = NULL
96+
)
97+
class(out) <- make_classes("logistic_reg", mode)
98+
out
99+
}
100+
101+
#' @export
102+
print.logistic_reg <- function(x, ...) {
103+
cat("Logistic Regression Model Specification (", x$mode, ")\n\n", sep = "")
104+
model_printer(x, ...)
105+
invisible(x)
106+
}
107+
108+
###################################################################
109+
110+
logistic_reg_glm_classification <- function () {
111+
libs <- "stats"
112+
interface <- "formula"
113+
protect = c("glm", "formula", "data", "weights")
114+
fit <-
115+
quote(
116+
glm(
117+
formula = missing_arg(),
118+
family = binomial(),
119+
data = missing_arg(),
120+
weights = missing_arg(),
121+
subset = missing_arg(),
122+
na.action = missing_arg(),
123+
start = NULL,
124+
etastart = missing_arg(),
125+
mustart = missing_arg(),
126+
offset = missing_arg(),
127+
control = list(...),
128+
model = TRUE,
129+
method = "glm.fit",
130+
x = FALSE,
131+
y = TRUE,
132+
contrasts = NULL,
133+
... = missing_arg()
134+
)
135+
)
136+
list(library = libs, interface = interface, fit = fit, protect = protect)
137+
}
138+
139+
logistic_reg_glmnet_classification <- function () {
140+
libs <- "glmnet"
141+
interface <- "data.frame"
142+
protect = c("glmnet", "x", "y", "weights")
143+
fit <-
144+
quote(
145+
glmnet(
146+
x = x,
147+
y = y,
148+
family = "binomial",
149+
weights = missing_arg(),
150+
offset = NULL,
151+
alpha = 1,
152+
nlambda = 100,
153+
lambda.min.ratio = ifelse(nobs < nvars, 0.01, 1e-04),
154+
lambda = NULL,
155+
standardize = TRUE,
156+
intercept = TRUE,
157+
thresh = 1e-07,
158+
dfmax = nvars + 1,
159+
pmax = min(dfmax * 2 + 20, nvars),
160+
exclude = missing_arg(),
161+
penalty.factor = rep(1, nvars),
162+
lower.limits = -Inf,
163+
upper.limits = Inf,
164+
maxit = 1e+05,
165+
type.gaussian = ifelse(nvars < 500, "covariance", "naive"),
166+
type.logistic = c("Newton", "modified.Newton"),
167+
standardize.response = FALSE,
168+
type.multinomial = c("ungrouped", "grouped")
169+
)
170+
)
171+
list(library = libs, interface = interface, fit = fit, protect = protect)
172+
}
173+
174+
logistic_reg_spark_classification <- function () {
175+
libs <- "sparklyr"
176+
interface <- "data.frame"
177+
protect = c("ml_logistic_regression", "x", "formula", "label_col", "features_col")
178+
fit <-
179+
quote(
180+
ml_logistic_regression(
181+
x = x,
182+
formula = NULL,
183+
fit_intercept = TRUE,
184+
elastic_net_param = 0,
185+
reg_param = 0,
186+
max_iter = 100L,
187+
threshold = 0.5,
188+
thresholds = NULL,
189+
tol = 1e-06,
190+
weight_col = NULL,
191+
aggregation_depth = 2L,
192+
lower_bounds_on_coefficients = NULL,
193+
lower_bounds_on_intercepts = NULL,
194+
upper_bounds_on_coefficients = NULL,
195+
upper_bounds_on_intercepts = NULL,
196+
features_col = "features",
197+
label_col = "label",
198+
family = "auto",
199+
prediction_col = "prediction",
200+
probability_col = "probability",
201+
raw_prediction_col = "rawPrediction",
202+
uid = random_string("logistic_regression_"),
203+
... = missing_arg()
204+
)
205+
)
206+
list(library = libs, interface = interface, fit = fit, protect = protect)
207+
}
208+
209+
logistic_reg_stan_glm_classification <- function () {
210+
libs <- "rstanarm"
211+
interface <- "formula"
212+
protect = c("stan_glm", "formula", "data", "weights")
213+
fit <-
214+
quote(
215+
stan_glm(
216+
formula = missing_arg(),
217+
family = binomial(),
218+
data = missing_arg(),
219+
weights = missing_arg(),
220+
subset = missing_arg(),
221+
na.action = NULL,
222+
offset = NULL,
223+
model = TRUE,
224+
x = FALSE,
225+
y = TRUE,
226+
contrasts = NULL,
227+
... = missing_arg(),
228+
prior = normal(),
229+
prior_intercept = normal(),
230+
prior_aux = exponential(),
231+
prior_PD = FALSE,
232+
algorithm = c("sampling", "optimizing", "meanfield", "fullrank"),
233+
adapt_delta = NULL,
234+
QR = FALSE,
235+
sparse = FALSE
236+
)
237+
)
238+
list(library = libs, interface = interface, fit = fit, protect = protect)
239+
}
240+
241+
#' @importFrom rlang quos
242+
#' @export
243+
finalize.logistic_reg <- function(x, engine = NULL, ...) {
244+
check_empty_ellipse(...)
245+
246+
x$engine <- engine
247+
x <- check_engine(x)
248+
249+
# exceptions and error trapping here
250+
if(engine %in% c("glm", "stan_glm") & !is.null(x$args$regularization)) {
251+
warning("The argument `regularization` cannot be used with this engine. ",
252+
"The value will be set to NULL")
253+
x$args$regularization <- quos(NULL)
254+
}
255+
if(engine %in% c("glm", "stan_glm") & !is.null(x$args$mixture)) {
256+
warning("The argument `mixture` cannot be used with this engine. ",
257+
"The value will be set to NULL")
258+
x$args$mixture <- quos(NULL)
259+
}
260+
261+
x$method <- get_model_objects(x, x$engine)()
262+
real_args <- deharmonize(x$args, logistic_reg_arg_key, x$engine)
263+
264+
# replace default args with user-specified
265+
x$method$fit <-
266+
sub_arg_values(x$method$fit, real_args, ignore = x$method$protect)
267+
268+
if (length(x$others) > 0) {
269+
protected <- names(x$others) %in% x$method$protect
270+
if (any(protected)) {
271+
warning(
272+
"The following options cannot be changed at this time ",
273+
"and were removed: ",
274+
paste0("`", names(x$others)[protected], "`", collapse = ", "),
275+
call. = FALSE
276+
)
277+
x$others <- x$others[-which(protected)]
278+
}
279+
}
280+
if (length(x$others) > 0)
281+
x$method$fit <- sub_arg_values(x$method$fit, x$others, ignore = x$method$protect)
282+
283+
# remove NULL and unmodified argument values
284+
modifed_args <- names(real_args)[!vapply(real_args, null_value, lgl(1))]
285+
x$method$fit <- prune_expr(x$method$fit, x$method$protect, c(modifed_args, names(x$others)))
286+
x
287+
}
288+
289+
###################################################################
290+
291+
#' @export
292+
update.logistic_reg <-
293+
function(object,
294+
link = NULL, regularization = NULL, mixture = NULL,
295+
engine_args = list(),
296+
fresh = FALSE,
297+
...) {
298+
check_empty_ellipse(...)
299+
300+
args <- list(
301+
link = rlang::enquo(link),
302+
regularization = rlang::enquo(regularization),
303+
mixture = rlang::enquo(mixture)
304+
)
305+
if (fresh) {
306+
object$args <- args
307+
} else {
308+
null_args <- map_lgl(args, null_value)
309+
if (any(null_args))
310+
args <- args[!null_args]
311+
if (length(args) > 0)
312+
object$args[names(args)] <- args
313+
}
314+
315+
others <- parse_engine_options(rlang::enquo(engine_args))
316+
if (length(others) > 0) {
317+
if (fresh)
318+
object$others <- others
319+
else
320+
object$others[names(others)] <- others
321+
}
322+
323+
object
324+
}
325+
326+
327+
###################################################################
328+
329+
logistic_reg_arg_key <- data.frame(
330+
glm = c("link", NA, NA),
331+
glmnet = c( NA, "lambda", "alpha"),
332+
spark = c( NA, "reg_param", "elastic_net_param"),
333+
stan_glm = c("link", NA, NA),
334+
stringsAsFactors = FALSE,
335+
row.names = c("link", "regularization", "mixture")
336+
)
337+
338+
logistic_reg_modes <- "classification"
339+
340+
logistic_reg_engines <- data.frame(
341+
glm = TRUE,
342+
glmnet = TRUE,
343+
spark = TRUE,
344+
stan_glm = TRUE,
345+
row.names = c("classification")
346+
)
347+

0 commit comments

Comments
 (0)