Skip to content

Commit fe24a99

Browse files
committed
added logistic and multinomial regression via keras
1 parent b84df46 commit fe24a99

15 files changed

+598
-37
lines changed

R/logistic_reg.R

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
#'
33
#' `logistic_reg` is a way to generate a _specification_ of a model
44
#' before fitting and allows the model to be created using
5-
#' different packages in R, Stan, or via Spark. The main arguments for the
6-
#' model are:
5+
#' different packages in R, Stan, keras, or via Spark. The main
6+
#' arguments for the model are:
77
#' \itemize{
88
#' \item \code{penalty}: The total amount of regularization
99
#' in the model. Note that this must be zero for some engines.
@@ -19,8 +19,11 @@
1919
#' @inheritParams boost_tree
2020
#' @param mode A single character string for the type of model.
2121
#' The only possible value for this model is "classification".
22-
#' @param penalty An non-negative number representing the
23-
#' total amount of regularization (`glmnet` and `spark` only).
22+
#' @param penalty An non-negative number representing the total
23+
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
24+
#' For `keras` models, this corresponds to purely L2 regularization
25+
#' (aka weight decay) while the other models can be a combination
26+
#' of L1 and L2 (depending on the value of `mixture`).
2427
#' @param mixture A number between zero and one (inclusive) that
2528
#' represents the proportion of regularization that is used for the
2629
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
@@ -34,6 +37,7 @@
3437
#' \item \pkg{R}: `"glm"` or `"glmnet"`
3538
#' \item \pkg{Stan}: `"stan"`
3639
#' \item \pkg{Spark}: `"spark"`
40+
#' \item \pkg{keras}: `"keras"`
3741
#' }
3842
#'
3943
#' @section Engine Details:
@@ -58,6 +62,10 @@
5862
#'
5963
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "spark")}
6064
#'
65+
#' \pkg{keras}
66+
#'
67+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::logistic_reg(), "keras")}
68+
#'
6169
#' When using `glmnet` models, there is the option to pass
6270
#' multiple values (or no values) to the `penalty` argument.
6371
#' This can have an effect on the model object results. When using

R/logistic_reg_data.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ logistic_reg_arg_key <- data.frame(
44
glmnet = c( "lambda", "alpha"),
55
spark = c("reg_param", "elastic_net_param"),
66
stan = c( NA, NA),
7+
keras = c( "decay", NA),
78
stringsAsFactors = FALSE,
89
row.names = c("penalty", "mixture")
910
)
@@ -15,6 +16,7 @@ logistic_reg_engines <- data.frame(
1516
glmnet = TRUE,
1617
spark = TRUE,
1718
stan = TRUE,
19+
keras = TRUE,
1820
row.names = c("classification")
1921
)
2022

@@ -290,3 +292,39 @@ logistic_reg_spark_data <-
290292
)
291293
)
292294

295+
logistic_reg_keras_data <-
296+
list(
297+
libs = c("keras", "magrittr"),
298+
fit = list(
299+
interface = "matrix",
300+
protect = c("x", "y"),
301+
func = c(pkg = "parsnip", fun = "keras_mlp"),
302+
defaults = list(hidden_units = 1, act = "linear")
303+
),
304+
class = list(
305+
pre = NULL,
306+
post = function(x, object) {
307+
object$lvl[x + 1]
308+
},
309+
func = c(pkg = "keras", fun = "predict_classes"),
310+
args =
311+
list(
312+
object = quote(object$fit),
313+
x = quote(as.matrix(new_data))
314+
)
315+
),
316+
classprob = list(
317+
pre = NULL,
318+
post = function(x, object) {
319+
x <- as_tibble(x)
320+
colnames(x) <- object$lvl
321+
x
322+
},
323+
func = c(pkg = "keras", fun = "predict_proba"),
324+
args =
325+
list(
326+
object = quote(object$fit),
327+
x = quote(as.matrix(new_data))
328+
)
329+
)
330+
)

R/multinom_reg.R

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#'
33
#' `multinom_reg` is a way to generate a _specification_ of a model
44
#' before fitting and allows the model to be created using
5-
#' different packages in R or Spark. The main arguments for the
5+
#' different packages in R, keras, or Spark. The main arguments for the
66
#' model are:
77
#' \itemize{
88
#' \item \code{penalty}: The total amount of regularization
@@ -19,8 +19,11 @@
1919
#' @inheritParams boost_tree
2020
#' @param mode A single character string for the type of model.
2121
#' The only possible value for this model is "classification".
22-
#' @param penalty An non-negative number representing the
23-
#' total amount of regularization.
22+
#' @param penalty An non-negative number representing the total
23+
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
24+
#' For `keras` models, this corresponds to purely L2 regularization
25+
#' (aka weight decay) while the other models can be a combination
26+
#' of L1 and L2 (depending on the value of `mixture`).
2427
#' @param mixture A number between zero and one (inclusive) that
2528
#' represents the proportion of regularization that is used for the
2629
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
@@ -33,6 +36,7 @@
3336
#' \itemize{
3437
#' \item \pkg{R}: `"glmnet"`
3538
#' \item \pkg{Stan}: `"stan"`
39+
#' \item \pkg{keras}: `"keras"`
3640
#' }
3741
#'
3842
#' @section Engine Details:
@@ -49,6 +53,10 @@
4953
#'
5054
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "spark")}
5155
#'
56+
#' \pkg{keras}
57+
#'
58+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::multinom_reg(), "keras")}
59+
#'
5260
#' When using `glmnet` models, there is the option to pass
5361
#' multiple values (or no values) to the `penalty` argument.
5462
#' This can have an effect on the model object results. When using

R/multinom_reg_data.R

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
multinom_reg_arg_key <- data.frame(
33
glmnet = c( "lambda", "alpha"),
44
spark = c("reg_param", "elastic_net_param"),
5+
keras = c( "decay", NA),
56
stringsAsFactors = FALSE,
67
row.names = c("penalty", "mixture")
78
)
@@ -11,6 +12,7 @@ multinom_reg_modes <- "classification"
1112
multinom_reg_engines <- data.frame(
1213
glmnet = TRUE,
1314
spark = TRUE,
15+
keras = TRUE,
1416
row.names = c("classification")
1517
)
1618

@@ -98,3 +100,39 @@ multinom_reg_spark_data <-
98100
)
99101

100102

103+
multinom_reg_keras_data <-
104+
list(
105+
libs = c("keras", "magrittr"),
106+
fit = list(
107+
interface = "matrix",
108+
protect = c("x", "y"),
109+
func = c(pkg = "parsnip", fun = "keras_mlp"),
110+
defaults = list(hidden_units = 1, act = "linear")
111+
),
112+
class = list(
113+
pre = NULL,
114+
post = function(x, object) {
115+
object$lvl[x + 1]
116+
},
117+
func = c(pkg = "keras", fun = "predict_classes"),
118+
args =
119+
list(
120+
object = quote(object$fit),
121+
x = quote(as.matrix(new_data))
122+
)
123+
),
124+
classprob = list(
125+
pre = NULL,
126+
post = function(x, object) {
127+
x <- as_tibble(x)
128+
colnames(x) <- object$lvl
129+
x
130+
},
131+
func = c(pkg = "keras", fun = "predict_proba"),
132+
args =
133+
list(
134+
object = quote(object$fit),
135+
x = quote(as.matrix(new_data))
136+
)
137+
)
138+
)

docs/articles/articles/Classification.html

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/articles/articles/Models.html

Lines changed: 54 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/articles/articles/Scratch.html

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/index.html

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/linear_reg.html

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)