55# ' different packages in R, Stan, or via Spark. The main arguments for the
66# ' model are:
77# ' \itemize{
8- # ' \item \code{link}: The link function.
98# ' \item \code{regularization}: The total amount of regularization
109# ' in the model. Note that this must be zero for some engines.
1110# ' \item \code{mixture}: The proportion of L2 regularization in
@@ -51,8 +50,6 @@ logistic_reg <- function (mode, ...)
5150# ' `rstanarm::stan_glm`, etc.). These are not evaluated
5251# ' until the model is fit and will be substituted into the model
5352# ' fit expression.
54- # ' @param link A character string for the link function. Possible
55- # ' values are "logit", "probit", "cauchit", "log" and "cloglog".
5653# ' @param regularization An non-negative number representing the
5754# ' total amount of regularization.
5855# ' @param mixture A number between zero and one (inclusive) that
@@ -65,7 +62,6 @@ logistic_reg <- function (mode, ...)
6562
6663logistic_reg.default <-
6764 function (mode = " classification" ,
68- link = NULL ,
6965 regularization = NULL ,
7066 mixture = NULL ,
7167 engine_args = list (),
@@ -79,7 +75,6 @@ logistic_reg.default <-
7975 )
8076
8177 args <- list (
82- link = rlang :: enquo(link ),
8378 regularization = rlang :: enquo(regularization ),
8479 mixture = rlang :: enquo(mixture )
8580 )
@@ -107,17 +102,18 @@ print.logistic_reg <- function(x, ...) {
107102
108103# ##################################################################
109104
105+ # ' @importFrom rlang missing_arg
110106logistic_reg_glm_classification <- function () {
111107 libs <- " stats"
112108 interface <- " formula"
113109 protect = c(" glm" , " formula" , " data" , " weights" )
114110 fit <-
115111 quote(
116112 glm(
117- formula = missing_arg() ,
113+ formula = formula ,
118114 family = binomial(),
119- data = missing_arg() ,
120- weights = missing_arg() ,
115+ data = data ,
116+ weights = NULL ,
121117 subset = missing_arg(),
122118 na.action = missing_arg(),
123119 start = NULL ,
@@ -143,7 +139,7 @@ logistic_reg_glmnet_classification <- function () {
143139 fit <-
144140 quote(
145141 glmnet(
146- x = x ,
142+ x = as.matrix( x ) ,
147143 y = y ,
148144 family = " binomial" ,
149145 weights = missing_arg(),
@@ -213,10 +209,10 @@ logistic_reg_stan_glm_classification <- function () {
213209 fit <-
214210 quote(
215211 stan_glm(
216- formula = missing_arg() ,
212+ formula = formula ,
217213 family = binomial(),
218- data = missing_arg() ,
219- weights = missing_arg() ,
214+ data = data ,
215+ weights = NULL ,
220216 subset = missing_arg(),
221217 na.action = NULL ,
222218 offset = NULL ,
@@ -247,23 +243,24 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
247243 x <- check_engine(x )
248244
249245 # exceptions and error trapping here
250- if (engine %in% c(" glm" , " stan_glm" ) & ! is.null (x $ args $ regularization )) {
246+ if (engine %in% c(" glm" , " stan_glm" ) & ! null_value (x $ args $ regularization )) {
251247 warning(" The argument `regularization` cannot be used with this engine. " ,
252248 " The value will be set to NULL" )
253249 x $ args $ regularization <- quos(NULL )
254250 }
255- if (engine %in% c(" glm" , " stan_glm" ) & ! is.null (x $ args $ mixture )) {
251+ if (engine %in% c(" glm" , " stan_glm" ) & ! null_value (x $ args $ mixture )) {
256252 warning(" The argument `mixture` cannot be used with this engine. " ,
257253 " The value will be set to NULL" )
258254 x $ args $ mixture <- quos(NULL )
259255 }
260256
261257 x $ method <- get_model_objects(x , x $ engine )()
262- real_args <- deharmonize(x $ args , logistic_reg_arg_key , x $ engine )
263-
264- # replace default args with user-specified
265- x $ method $ fit <-
266- sub_arg_values(x $ method $ fit , real_args , ignore = x $ method $ protect )
258+ if (! (engine %in% c(" glm" , " stan_glm" ))) {
259+ real_args <- deharmonize(x $ args , logistic_reg_arg_key , x $ engine )
260+ # replace default args with user-specified
261+ x $ method $ fit <-
262+ sub_arg_values(x $ method $ fit , real_args , ignore = x $ method $ protect )
263+ }
267264
268265 if (length(x $ others ) > 0 ) {
269266 protected <- names(x $ others ) %in% x $ method $ protect
@@ -281,7 +278,16 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
281278 x $ method $ fit <- sub_arg_values(x $ method $ fit , x $ others , ignore = x $ method $ protect )
282279
283280 # remove NULL and unmodified argument values
284- modifed_args <- names(real_args )[! vapply(real_args , null_value , lgl(1 ))]
281+ modifed_args <- if (! (engine %in% c(" glm" , " stan_glm" )))
282+ names(real_args )[! vapply(real_args , null_value , lgl(1 ))]
283+ else
284+ NULL
285+ modifed_args <- unique(c(" family" , modifed_args ))
286+
287+ # glmnet can't handle NULL weights
288+ if (engine == " glmnet" & identical(x $ method $ fit $ weights , quote(missing_arg())))
289+ x $ method $ protect <- x $ method $ protect [x $ method $ protect != " weights" ]
290+
285291 x $ method $ fit <- prune_expr(x $ method $ fit , x $ method $ protect , c(modifed_args , names(x $ others )))
286292 x
287293}
@@ -291,14 +297,13 @@ finalize.logistic_reg <- function(x, engine = NULL, ...) {
291297# ' @export
292298update.logistic_reg <-
293299 function (object ,
294- link = NULL , regularization = NULL , mixture = NULL ,
300+ regularization = NULL , mixture = NULL ,
295301 engine_args = list (),
296302 fresh = FALSE ,
297303 ... ) {
298304 check_empty_ellipse(... )
299305
300306 args <- list (
301- link = rlang :: enquo(link ),
302307 regularization = rlang :: enquo(regularization ),
303308 mixture = rlang :: enquo(mixture )
304309 )
@@ -327,12 +332,12 @@ update.logistic_reg <-
327332# ##################################################################
328333
329334logistic_reg_arg_key <- data.frame (
330- glm = c(" link " , NA , NA ),
331- glmnet = c( NA , " lambda" , " alpha" ),
332- spark = c( NA , " reg_param" , " elastic_net_param" ),
333- stan_glm = c(" link " , NA , NA ),
335+ glm = c( NA , NA ),
336+ glmnet = c( " lambda" , " alpha" ),
337+ spark = c(" reg_param" , " elastic_net_param" ),
338+ stan_glm = c( NA , NA ),
334339 stringsAsFactors = FALSE ,
335- row.names = c(" link " , " regularization" , " mixture" )
340+ row.names = c(" regularization" , " mixture" )
336341)
337342
338343logistic_reg_modes <- " classification"
0 commit comments