Skip to content

Commit fd898e0

Browse files
committed
more linear regression glmnet updates for #195
1 parent bdd1d86 commit fd898e0

File tree

5 files changed

+30
-19
lines changed

5 files changed

+30
-19
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Package: parsnip
2-
Version: 0.0.2.9000
2+
Version: 0.0.3
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ S3method(print,svm_rbf)
5656
S3method(translate,boost_tree)
5757
S3method(translate,decision_tree)
5858
S3method(translate,default)
59+
S3method(translate,linear_reg)
5960
S3method(translate,mars)
6061
S3method(translate,mlp)
6162
S3method(translate,nearest_neighbor)

NEWS.md

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
1-
# parsnip 0.0.2.9000
1+
# parsnip 0.0.3
2+
3+
Unplanned release based on CRAN requirements for Solaris.
24

35
## Breaking Changes
46

57
* The method that `parsnip` stores the model information has changed. Any custom models from previous versions will need to use the new method for registering models. The methods are detailed in `?get_model_env()` and the [package vignette for adding models](https://tidymodels.github.io/parsnip/articles/articles/Scratch.html).
6-
* The mode need to be declared for models that can be used for more than one mode prior to fitting and/or translation).
8+
9+
* The mode need to be declared for models that can be used for more than one mode prior to fitting and/or translation.
10+
711
* For `surv_reg()`, the engine that uses the `survival` package is now called `survival` instead of `survreg`.
812

13+
* For `glmnet` models, the full regularization path is always fit regardless of the value given to `penalty`. Previously, the model was fit with passing `penalty` to `glmnet`'s `lambda` argument and the model could only make predictions at those specific values. [(#195)](https://github.com/tidymodels/parsnip/issues/195)
14+
915
## New Features
1016

1117
* `add_rowindex()` can create a column called `.row` to a data frame.
1218

1319
* If a computational engine is not explicitly set, a default will be used. Each default is documented on the corresponding model page. A warning is issued at fit time unless verbosity is zero.
20+
1421
* `nearest_neighbor` gained a `multi_predict` method. The `multi_predict()` documentation is a little better organized.
22+
23+
* A suite of internal functions were added to help with upcoming model tuning features.
1524

1625

1726
# parsnip 0.0.2

R/linear_reg.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,12 @@ multi_predict._elnet <-
336336
object$spec <- eval_args(object$spec)
337337

338338
if (is.null(penalty)) {
339-
penalty <- object$fit$lambda
339+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
340+
if (!is.null(object$spec$args$penalty)) {
341+
penalty <- object$spec$args$penalty
342+
} else {
343+
penalty <- object$fit$lambda
344+
}
340345
}
341346

342347
pred <- predict._elnet(object, new_data = new_data, type = "raw",

tests/testthat/test_linear_reg_glmnet.R

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ test_that('glmnet prediction, single lambda', {
6767
y = iris$Sepal.Length
6868
)
6969

70-
uni_pred <- c(5.05124049139868, 4.87103404621362, 4.91028250633598, 4.9399094532023,
71-
5.08728178043569)
70+
uni_pred <- c(5.05125589060219, 4.86977761622526, 4.90912345599309, 4.93931874108359,
71+
5.08755154547758)
7272

73-
expect_equal(uni_pred, predict(res_xy, iris[1:5, num_pred])$.pred)
73+
expect_equal(uni_pred, predict(res_xy, iris[1:5, num_pred])$.pred, tolerance = 0.0001)
7474

7575
res_form <- fit(
7676
iris_basic,
@@ -79,10 +79,10 @@ test_that('glmnet prediction, single lambda', {
7979
control = ctrl
8080
)
8181

82-
form_pred <- c(5.24228948237804, 5.09448280355765, 5.15636527125752, 5.12592317615935,
83-
5.26930099973607)
82+
form_pred <- c(5.23960117346944, 5.08769210344022, 5.15129212608077, 5.12000510716518,
83+
5.26736239856889)
8484

85-
expect_equal(form_pred, predict(res_form, iris[1:5,])$.pred)
85+
expect_equal(form_pred, predict(res_form, iris[1:5,])$.pred, tolerance = 0.0001)
8686
})
8787

8888

@@ -132,7 +132,8 @@ test_that('glmnet prediction, multiple lambda', {
132132
as.data.frame(mult_pred),
133133
multi_predict(res_xy, new_data = iris[1:5, num_pred], lambda = lams) %>%
134134
unnest() %>%
135-
as.data.frame()
135+
as.data.frame(),
136+
tolerance = 0.0001
136137
)
137138

138139
res_form <- fit(
@@ -176,7 +177,8 @@ test_that('glmnet prediction, multiple lambda', {
176177
as.data.frame(form_pred),
177178
multi_predict(res_form, new_data = iris[1:5, ], lambda = lams) %>%
178179
unnest() %>%
179-
as.data.frame()
180+
as.data.frame(),
181+
tolerance = 0.0001
180182
)
181183
})
182184

@@ -249,7 +251,7 @@ test_that('submodel prediction', {
249251
)
250252

251253
reg_fit <-
252-
linear_reg(penalty = c(0, 0.01, 0.1)) %>%
254+
linear_reg() %>%
253255
set_engine("glmnet") %>%
254256
fit(mpg ~ ., data = mtcars[-(1:4), ])
255257

@@ -274,12 +276,6 @@ test_that('error traps', {
274276

275277
skip_if_not_installed("glmnet")
276278

277-
expect_error(
278-
linear_reg(penalty = .1) %>%
279-
set_engine("glmnet") %>%
280-
fit(mpg ~ ., data = mtcars[-(1:4), ]) %>%
281-
predict(mtcars[-(1:4), ], penalty = .2)
282-
)
283279
expect_error(
284280
linear_reg() %>%
285281
set_engine("glmnet") %>%

0 commit comments

Comments
 (0)