Skip to content

Commit fdf7abc

Browse files
committed
keras linear regression
1 parent 4964061 commit fdf7abc

File tree

7 files changed

+218
-20
lines changed

7 files changed

+218
-20
lines changed

R/linear_reg.R

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
#'
33
#' `linear_reg` is a way to generate a _specification_ of a model
44
#' before fitting and allows the model to be created using
5-
#' different packages in R, Stan, or via Spark. The main arguments for the
6-
#' model are:
5+
#' different packages in R, Stan, keras, or via Spark. The main
6+
#' arguments for the model are:
77
#' \itemize{
88
#' \item \code{penalty}: The total amount of regularization
9-
#' in the model. Note that this must be zero for some engines .
9+
#' in the model. Note that this must be zero for some engines.
1010
#' \item \code{mixture}: The proportion of L1 regularization in
1111
#' the model. Note that this will be ignored for some engines.
1212
#' }
@@ -19,8 +19,11 @@
1919
#' @inheritParams boost_tree
2020
#' @param mode A single character string for the type of model.
2121
#' The only possible value for this model is "regression".
22-
#' @param penalty An non-negative number representing the
23-
#' total amount of regularization (`glmnet` and `spark` only).
22+
#' @param penalty An non-negative number representing the total
23+
#' amount of regularization (`glmnet`, `keras`, and `spark` only).
24+
#' For `keras` models, this corresponds to purely L2 regularization
25+
#' (aka weight decay) while the other models can be a combination
26+
#' of L1 and L2 (depending on the value of `mixture`).
2427
#' @param mixture A number between zero and one (inclusive) that
2528
#' represents the proportion of regularization that is used for the
2629
#' L2 penalty (i.e. weight decay, or ridge regression) versus L1
@@ -36,6 +39,7 @@
3639
#' \item \pkg{R}: `"lm"` or `"glmnet"`
3740
#' \item \pkg{Stan}: `"stan"`
3841
#' \item \pkg{Spark}: `"spark"`
42+
#' \item \pkg{keras}: `"keras"`
3943
#' }
4044
#'
4145
#' @section Engine Details:
@@ -59,6 +63,10 @@
5963
#' \pkg{spark}
6064
#'
6165
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
66+
#'
67+
#' \pkg{keras}
68+
#'
69+
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
6270
#'
6371
#' When using `glmnet` models, there is the option to pass
6472
#' multiple values (or no values) to the `penalty` argument.

R/linear_reg_data.R

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,19 @@ linear_reg_arg_key <- data.frame(
44
glmnet = c( "lambda", "alpha"),
55
spark = c("reg_param", "elastic_net_param"),
66
stan = c( NA, NA),
7+
keras = c( "decay", NA),
78
stringsAsFactors = FALSE,
89
row.names = c("penalty", "mixture")
910
)
1011

1112
linear_reg_modes <- "regression"
1213

1314
linear_reg_engines <- data.frame(
14-
lm = TRUE,
15+
lm = TRUE,
1516
glmnet = TRUE,
1617
spark = TRUE,
1718
stan = TRUE,
19+
keras = TRUE,
1820
row.names = c("regression")
1921
)
2022

@@ -240,5 +242,24 @@ linear_reg_spark_data <-
240242
)
241243
)
242244

243-
245+
linear_reg_keras_data <-
246+
list(
247+
libs = c("keras", "magrittr"),
248+
fit = list(
249+
interface = "matrix",
250+
protect = c("x", "y"),
251+
func = c(pkg = "parsnip", fun = "keras_mlp"),
252+
defaults = list(hidden_units = 1, act = "linear")
253+
),
254+
numeric = list(
255+
pre = NULL,
256+
post = maybe_multivariate,
257+
func = c(fun = "predict"),
258+
args =
259+
list(
260+
object = quote(object$fit),
261+
x = quote(as.matrix(new_data))
262+
)
263+
)
264+
)
244265

docs/articles/articles/Classification.html

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/articles/articles/Models.html

Lines changed: 27 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/linear_reg.html

Lines changed: 13 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/linear_reg.Rd

Lines changed: 13 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
library(testthat)
2+
library(parsnip)
3+
library(rlang)
4+
library(tibble)
5+
6+
# ------------------------------------------------------------------------------
7+
8+
context("keras linear regression")
9+
source("helpers.R")
10+
11+
# ------------------------------------------------------------------------------
12+
13+
basic_mod <-
14+
linear_reg() %>%
15+
set_engine("keras", epochs = 50, verbose = 0)
16+
17+
ridge_mod <-
18+
linear_reg(penalty = 0.1) %>%
19+
set_engine("keras", epochs = 50, verbose = 0)
20+
21+
ctrl <- fit_control(verbosity = 0, catch = FALSE)
22+
23+
# ------------------------------------------------------------------------------
24+
25+
test_that('model fitting', {
26+
27+
skip_if_not_installed("keras")
28+
29+
set.seed(257)
30+
expect_error(
31+
fit1 <-
32+
fit_xy(
33+
basic_mod,
34+
control = ctrl,
35+
x = iris[,2:4],
36+
y = iris$Sepal.Length
37+
),
38+
regexp = NA
39+
)
40+
41+
set.seed(257)
42+
expect_error(
43+
fit2 <-
44+
fit_xy(
45+
basic_mod,
46+
control = ctrl,
47+
x = iris[,2:4],
48+
y = iris$Sepal.Length
49+
),
50+
regexp = NA
51+
)
52+
expect_equal(fit1, fit2)
53+
54+
expect_error(
55+
fit(
56+
basic_mod,
57+
Sepal.Length ~ .,
58+
data = iris[, -5],
59+
control = ctrl
60+
),
61+
regexp = NA
62+
)
63+
64+
expect_error(
65+
fit1 <-
66+
fit_xy(
67+
ridge_mod,
68+
control = ctrl,
69+
x = iris[,2:4],
70+
y = iris$Sepal.Length
71+
),
72+
regexp = NA
73+
)
74+
75+
expect_error(
76+
fit(
77+
ridge_mod,
78+
Sepal.Length ~ .,
79+
data = iris[, -5],
80+
control = ctrl
81+
),
82+
regexp = NA
83+
)
84+
85+
})
86+
87+
88+
test_that('regression prediction', {
89+
90+
skip_if_not_installed("keras")
91+
92+
library(keras)
93+
94+
set.seed(257)
95+
lm_fit <-
96+
fit_xy(
97+
basic_mod,
98+
control = ctrl,
99+
x = iris[,2:4],
100+
y = iris$Sepal.Length
101+
)
102+
103+
keras_pred <-
104+
predict(lm_fit$fit, as.matrix(iris[1:3,2:4])) %>%
105+
as_tibble() %>%
106+
setNames(".pred")
107+
parsnip_pred <- predict(lm_fit, iris[1:3,2:4])
108+
expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
109+
110+
set.seed(257)
111+
rr_fit <-
112+
fit_xy(
113+
ridge_mod,
114+
control = ctrl,
115+
x = iris[,2:4],
116+
y = iris$Sepal.Length
117+
)
118+
119+
keras_pred <-
120+
predict(rr_fit$fit, as.matrix(iris[1:3,2:4])) %>%
121+
as_tibble() %>%
122+
setNames(".pred")
123+
parsnip_pred <- predict(rr_fit, iris[1:3,2:4])
124+
expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
125+
126+
})

0 commit comments

Comments
 (0)