Skip to content

Commit b8f02a3

Browse files
committed
Updates tests for new glmnet error
1 parent 83c5900 commit b8f02a3

File tree

3 files changed

+33
-73
lines changed

3 files changed

+33
-73
lines changed

tests/testthat/test_linear_reg.R

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
1515
test_that('primary arguments', {
1616
basic <- linear_reg()
1717
basic_lm <- translate(basic %>% set_engine("lm"))
18-
basic_glmnet <- translate(basic %>% set_engine("glmnet"))
18+
expect_error(
19+
basic_glmnet <- translate(basic %>% set_engine("glmnet")),
20+
"For the glmnet engine, `penalty` must be a single"
21+
)
1922
basic_stan <- translate(basic %>% set_engine("stan"))
2023
basic_spark <- translate(basic %>% set_engine("spark"))
2124
expect_equal(basic_lm$method$fit$args,
@@ -25,14 +28,6 @@ test_that('primary arguments', {
2528
weights = expr(missing_arg())
2629
)
2730
)
28-
expect_equal(basic_glmnet$method$fit$args,
29-
list(
30-
x = expr(missing_arg()),
31-
y = expr(missing_arg()),
32-
weights = expr(missing_arg()),
33-
family = "gaussian"
34-
)
35-
)
3631
expect_equal(basic_stan$method$fit$args,
3732
list(
3833
formula = expr(missing_arg()),
@@ -51,17 +46,11 @@ test_that('primary arguments', {
5146
)
5247

5348
mixture <- linear_reg(mixture = 0.128)
54-
mixture_glmnet <- translate(mixture %>% set_engine("glmnet"))
55-
mixture_spark <- translate(mixture %>% set_engine("spark"))
56-
expect_equal(mixture_glmnet$method$fit$args,
57-
list(
58-
x = expr(missing_arg()),
59-
y = expr(missing_arg()),
60-
weights = expr(missing_arg()),
61-
alpha = new_empty_quosure(0.128),
62-
family = "gaussian"
63-
)
49+
expect_error(
50+
mixture_glmnet <- translate(mixture %>% set_engine("glmnet")),
51+
"For the glmnet engine, `penalty` must be a single"
6452
)
53+
mixture_spark <- translate(mixture %>% set_engine("spark"))
6554
expect_equal(mixture_spark$method$fit$args,
6655
list(
6756
x = expr(missing_arg()),
@@ -92,17 +81,11 @@ test_that('primary arguments', {
9281
)
9382

9483
mixture_v <- linear_reg(mixture = varying())
95-
mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet"))
96-
mixture_v_spark <- translate(mixture_v %>% set_engine("spark"))
97-
expect_equal(mixture_v_glmnet$method$fit$args,
98-
list(
99-
x = expr(missing_arg()),
100-
y = expr(missing_arg()),
101-
weights = expr(missing_arg()),
102-
alpha = new_empty_quosure(varying()),
103-
family = "gaussian"
104-
)
84+
expect_error(
85+
mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")),
86+
"For the glmnet engine, `penalty` must be a single"
10587
)
88+
mixture_v_spark <- translate(mixture_v %>% set_engine("spark"))
10689
expect_equal(mixture_v_spark$method$fit$args,
10790
list(
10891
x = expr(missing_arg()),
@@ -125,7 +108,7 @@ test_that('engine arguments', {
125108
)
126109
)
127110

128-
glmnet_nlam <- linear_reg() %>% set_engine("glmnet", nlambda = 10)
111+
glmnet_nlam <- linear_reg(penalty = 0.1) %>% set_engine("glmnet", nlambda = 10)
129112
expect_equal(translate(glmnet_nlam)$method$fit$args,
130113
list(
131114
x = expr(missing_arg()),

tests/testthat/test_logistic_reg.R

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
1616
test_that('primary arguments', {
1717
basic <- logistic_reg()
1818
basic_glm <- translate(basic %>% set_engine("glm"))
19-
basic_glmnet <- translate(basic %>% set_engine("glmnet"))
19+
expect_error(
20+
basic_glmnet <- translate(basic %>% set_engine("glmnet")),
21+
"For the glmnet engine, `penalty` must be a single"
22+
)
2023
basic_liblinear <- translate(basic %>% set_engine("LiblineaR"))
2124
basic_stan <- translate(basic %>% set_engine("stan"))
2225
basic_spark <- translate(basic %>% set_engine("spark"))
@@ -28,14 +31,6 @@ test_that('primary arguments', {
2831
family = expr(stats::binomial)
2932
)
3033
)
31-
expect_equal(basic_glmnet$method$fit$args,
32-
list(
33-
x = expr(missing_arg()),
34-
y = expr(missing_arg()),
35-
weights = expr(missing_arg()),
36-
family = "binomial"
37-
)
38-
)
3934
expect_equal(basic_liblinear$method$fit$args,
4035
list(
4136
x = expr(missing_arg()),
@@ -63,17 +58,11 @@ test_that('primary arguments', {
6358
)
6459

6560
mixture <- logistic_reg(mixture = 0.128)
66-
mixture_glmnet <- translate(mixture %>% set_engine("glmnet"))
67-
mixture_spark <- translate(mixture %>% set_engine("spark"))
68-
expect_equal(mixture_glmnet$method$fit$args,
69-
list(
70-
x = expr(missing_arg()),
71-
y = expr(missing_arg()),
72-
weights = expr(missing_arg()),
73-
alpha = new_empty_quosure(0.128),
74-
family = "binomial"
75-
)
61+
expect_error(
62+
mixture_glmnet <- translate(mixture %>% set_engine("glmnet")),
63+
"For the glmnet engine, `penalty` must be a single"
7664
)
65+
mixture_spark <- translate(mixture %>% set_engine("spark"))
7766
expect_equal(mixture_spark$method$fit$args,
7867
list(
7968
x = expr(missing_arg()),
@@ -116,18 +105,12 @@ test_that('primary arguments', {
116105
)
117106

118107
mixture_v <- logistic_reg(mixture = varying())
119-
mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet"))
108+
expect_error(
109+
mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet")),
110+
"For the glmnet engine, `penalty` must be a single"
111+
)
120112
mixture_v_liblinear <- translate(mixture_v %>% set_engine("LiblineaR"))
121113
mixture_v_spark <- translate(mixture_v %>% set_engine("spark"))
122-
expect_equal(mixture_v_glmnet$method$fit$args,
123-
list(
124-
x = expr(missing_arg()),
125-
y = expr(missing_arg()),
126-
weights = expr(missing_arg()),
127-
alpha = new_empty_quosure(varying()),
128-
family = "binomial"
129-
)
130-
)
131114
expect_equal(mixture_v_liblinear$method$fit$args,
132115
list(
133116
x = expr(missing_arg()),
@@ -194,7 +177,7 @@ test_that('engine arguments', {
194177
)
195178
)
196179

197-
glmnet_nlam <- logistic_reg()
180+
glmnet_nlam <- logistic_reg(penalty = 0.1)
198181
expect_equal(
199182
translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args,
200183
list(

tests/testthat/test_multinom_reg.R

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,11 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
1313

1414
test_that('primary arguments', {
1515
basic <- multinom_reg()
16-
basic_glmnet <- translate(basic %>% set_engine("glmnet"))
17-
expect_equal(basic_glmnet$method$fit$args,
18-
list(
19-
x = expr(missing_arg()),
20-
y = expr(missing_arg()),
21-
weights = expr(missing_arg()),
22-
family = "multinomial"
23-
)
16+
expect_error(
17+
basic_glmnet <- translate(basic %>% set_engine("glmnet")),
18+
"For the glmnet engine, `penalty` must be a single"
2419
)
25-
26-
mixture <- multinom_reg(mixture = 0.128)
20+
mixture <- multinom_reg(penalty = 0.1, mixture = 0.128)
2721
mixture_glmnet <- translate(mixture %>% set_engine("glmnet"))
2822
expect_equal(mixture_glmnet$method$fit$args,
2923
list(
@@ -46,7 +40,7 @@ test_that('primary arguments', {
4640
)
4741
)
4842

49-
mixture_v <- multinom_reg(mixture = varying())
43+
mixture_v <- multinom_reg(penalty = 0.01, mixture = varying())
5044
mixture_v_glmnet <- translate(mixture_v %>% set_engine("glmnet"))
5145
expect_equal(mixture_v_glmnet$method$fit$args,
5246
list(
@@ -61,7 +55,7 @@ test_that('primary arguments', {
6155
})
6256

6357
test_that('engine arguments', {
64-
glmnet_nlam <- multinom_reg()
58+
glmnet_nlam <- multinom_reg(penalty = 0.01)
6559
expect_equal(
6660
translate(glmnet_nlam %>% set_engine("glmnet", nlambda = 10))$method$fit$args,
6761
list(
@@ -117,5 +111,5 @@ test_that('bad input', {
117111
expect_error(multinom_reg(mode = "regression"))
118112
expect_error(translate(multinom_reg() %>% set_engine("wat?")))
119113
expect_error(translate(multinom_reg() %>% set_engine()))
120-
expect_warning(translate(multinom_reg() %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
114+
expect_warning(translate(multinom_reg(penalty = 0.01) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
121115
})

0 commit comments

Comments
 (0)