Skip to content

Commit a5e7d35

Browse files
committed
removed link as a parameter
1 parent 7eaf0c9 commit a5e7d35

File tree

2 files changed

+33
-32
lines changed

2 files changed

+33
-32
lines changed

R/logistic_reg.R

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
#' different packages in R, Stan, or via Spark. The main arguments for the
66
#' model are:
77
#' \itemize{
8-
#' \item \code{link}: The link function.
98
#' \item \code{regularization}: The total amount of regularization
109
#' in the model. Note that this must be zero for some engines.
1110
#' \item \code{mixture}: The proportion of L2 regularization in
@@ -51,8 +50,6 @@ logistic_reg <- function (mode, ...)
5150
#' `rstanarm::stan_glm`, etc.). These are not evaluated
5251
#' until the model is fit and will be substituted into the model
5352
#' fit expression.
54-
#' @param link A character string for the link function. Possible
55-
#' values are "logit", "probit", "cauchit", "log" and "cloglog".
5653
#' @param regularization An non-negative number representing the
5754
#' total amount of regularization.
5855
#' @param mixture A number between zero and one (inclusive) that
@@ -65,7 +62,6 @@ logistic_reg <- function (mode, ...)
6562

6663
logistic_reg.default <-
6764
function(mode = "classification",
68-
link = NULL,
6965
regularization = NULL,
7066
mixture = NULL,
7167
engine_args = list(),
@@ -79,7 +75,6 @@ logistic_reg.default <-
7975
)
8076

8177
args <- list(
82-
link = rlang::enquo(link),
8378
regularization = rlang::enquo(regularization),
8479
mixture = rlang::enquo(mixture)
8580
)
@@ -107,17 +102,18 @@ print.logistic_reg <- function(x, ...) {
107102

108103
###################################################################
109104

105+
#' @importFrom rlang missing_arg
110106
logistic_reg_glm_classification <- function () {
111107
libs <- "stats"
112108
interface <- "formula"
113109
protect = c("glm", "formula", "data", "weights")
114110
fit <-
115111
quote(
116112
glm(
117-
formula = missing_arg(),
113+
formula = formula,
118114
family = binomial(),
119-
data = missing_arg(),
120-
weights = missing_arg(),
115+
data = data,
116+
weights = NULL,
121117
subset = missing_arg(),
122118
na.action = missing_arg(),
123119
start = NULL,
@@ -143,7 +139,7 @@ logistic_reg_glmnet_classification <- function () {
143139
fit <-
144140
quote(
145141
glmnet(
146-
x = x,
142+
x = as.matrix(x),
147143
y = y,
148144
family = "binomial",
149145
weights = missing_arg(),
@@ -213,10 +209,10 @@ logistic_reg_stan_glm_classification <- function () {
213209
fit <-
214210
quote(
215211
stan_glm(
216-
formula = missing_arg(),
212+
formula = formula,
217213
family = binomial(),
218-
data = missing_arg(),
219-
weights = missing_arg(),
214+
data = data,
215+
weights = NULL,
220216
subset = missing_arg(),
221217
na.action = NULL,
222218
offset = NULL,
@@ -247,23 +243,24 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
247243
x <- check_engine(x)
248244

249245
# exceptions and error trapping here
250-
if(engine %in% c("glm", "stan_glm") & !is.null(x$args$regularization)) {
246+
if(engine %in% c("glm", "stan_glm") & !null_value(x$args$regularization)) {
251247
warning("The argument `regularization` cannot be used with this engine. ",
252248
"The value will be set to NULL")
253249
x$args$regularization <- quos(NULL)
254250
}
255-
if(engine %in% c("glm", "stan_glm") & !is.null(x$args$mixture)) {
251+
if(engine %in% c("glm", "stan_glm") & !null_value(x$args$mixture)) {
256252
warning("The argument `mixture` cannot be used with this engine. ",
257253
"The value will be set to NULL")
258254
x$args$mixture <- quos(NULL)
259255
}
260256

261257
x$method <- get_model_objects(x, x$engine)()
262-
real_args <- deharmonize(x$args, logistic_reg_arg_key, x$engine)
263-
264-
# replace default args with user-specified
265-
x$method$fit <-
266-
sub_arg_values(x$method$fit, real_args, ignore = x$method$protect)
258+
if(!(engine %in% c("glm", "stan_glm"))) {
259+
real_args <- deharmonize(x$args, logistic_reg_arg_key, x$engine)
260+
# replace default args with user-specified
261+
x$method$fit <-
262+
sub_arg_values(x$method$fit, real_args, ignore = x$method$protect)
263+
}
267264

268265
if (length(x$others) > 0) {
269266
protected <- names(x$others) %in% x$method$protect
@@ -281,7 +278,16 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
281278
x$method$fit <- sub_arg_values(x$method$fit, x$others, ignore = x$method$protect)
282279

283280
# remove NULL and unmodified argument values
284-
modifed_args <- names(real_args)[!vapply(real_args, null_value, lgl(1))]
281+
modifed_args <- if (!(engine %in% c("glm", "stan_glm")))
282+
names(real_args)[!vapply(real_args, null_value, lgl(1))]
283+
else
284+
NULL
285+
modifed_args <- unique(c("family", modifed_args))
286+
287+
# glmnet can't handle NULL weights
288+
if (engine == "glmnet" & identical(x$method$fit$weights, quote(missing_arg())))
289+
x$method$protect <- x$method$protect[x$method$protect != "weights"]
290+
285291
x$method$fit <- prune_expr(x$method$fit, x$method$protect, c(modifed_args, names(x$others)))
286292
x
287293
}
@@ -291,14 +297,13 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
291297
#' @export
292298
update.logistic_reg <-
293299
function(object,
294-
link = NULL, regularization = NULL, mixture = NULL,
300+
regularization = NULL, mixture = NULL,
295301
engine_args = list(),
296302
fresh = FALSE,
297303
...) {
298304
check_empty_ellipse(...)
299305

300306
args <- list(
301-
link = rlang::enquo(link),
302307
regularization = rlang::enquo(regularization),
303308
mixture = rlang::enquo(mixture)
304309
)
@@ -327,12 +332,12 @@ update.logistic_reg <-
327332
###################################################################
328333

329334
logistic_reg_arg_key <- data.frame(
330-
glm = c("link", NA, NA),
331-
glmnet = c( NA, "lambda", "alpha"),
332-
spark = c( NA, "reg_param", "elastic_net_param"),
333-
stan_glm = c("link", NA, NA),
335+
glm = c( NA, NA),
336+
glmnet = c( "lambda", "alpha"),
337+
spark = c("reg_param", "elastic_net_param"),
338+
stan_glm = c( NA, NA),
334339
stringsAsFactors = FALSE,
335-
row.names = c("link", "regularization", "mixture")
340+
row.names = c("regularization", "mixture")
336341
)
337342

338343
logistic_reg_modes <- "classification"

man/logistic_reg.Rd

Lines changed: 1 addition & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)