Skip to content

Commit ccc2fd5

Browse files
committed
prev failed cases now pass; had to readjust argument order for tests
1 parent 57d4cab commit ccc2fd5

File tree

2 files changed

+111
-126
lines changed

2 files changed

+111
-126
lines changed

tests/testthat/test_logistic_reg.R

Lines changed: 63 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,33 @@ test_that('primary arguments', {
1212
expect_equal(basic_glm$method$fit_args,
1313
list(
1414
formula = quote(missing_arg()),
15-
family = quote(binomial),
1615
data = quote(missing_arg()),
17-
weights = quote(missing_arg())
16+
weights = quote(missing_arg()),
17+
family = quote(binomial)
1818
)
1919
)
2020
expect_equal(basic_glmnet$method$fit_args,
2121
list(
2222
x = quote(missing_arg()),
2323
y = quote(missing_arg()),
24-
family = "binomial",
25-
weights = quote(missing_arg())
24+
weights = quote(missing_arg()),
25+
family = "binomial"
2626
)
2727
)
2828
expect_equal(basic_stan$method$fit_args,
2929
list(
3030
formula = quote(missing_arg()),
31-
family = quote(binomial),
3231
data = quote(missing_arg()),
33-
weights = quote(missing_arg())
32+
weights = quote(missing_arg()),
33+
family = quote(binomial)
3434
)
3535
)
3636
expect_equal(basic_spark$method$fit_args,
3737
list(
3838
x = quote(missing_arg()),
39+
formula = quote(missing_arg()),
3940
weight_col = quote(missing_arg()),
40-
features_col = quote(missing_arg()),
41-
label_col = quote(missing_arg()),
42-
family = quote(binomial)
41+
family = "binomial"
4342
)
4443
)
4544

@@ -50,19 +49,18 @@ test_that('primary arguments', {
5049
list(
5150
x = quote(missing_arg()),
5251
y = quote(missing_arg()),
53-
family = "binomial",
5452
weights = quote(missing_arg()),
55-
alpha = 0.128
53+
alpha = 0.128,
54+
family = "binomial"
5655
)
5756
)
5857
expect_equal(mixture_spark$method$fit_args,
5958
list(
6059
x = quote(missing_arg()),
61-
elastic_net_param = 0.128,
60+
formula = quote(missing_arg()),
6261
weight_col = quote(missing_arg()),
63-
features_col = quote(missing_arg()),
64-
label_col = quote(missing_arg()),
65-
family = quote(binomial)
62+
elastic_net_param = 0.128,
63+
family = "binomial"
6664
)
6765
)
6866

@@ -73,19 +71,18 @@ test_that('primary arguments', {
7371
list(
7472
x = quote(missing_arg()),
7573
y = quote(missing_arg()),
76-
family = "binomial",
7774
weights = quote(missing_arg()),
78-
lambda = 1
75+
lambda = 1,
76+
family = "binomial"
7977
)
8078
)
8179
expect_equal(regularization_spark$method$fit_args,
8280
list(
8381
x = quote(missing_arg()),
84-
reg_param = 1,
82+
formula = quote(missing_arg()),
8583
weight_col = quote(missing_arg()),
86-
features_col = quote(missing_arg()),
87-
label_col = quote(missing_arg()),
88-
family = quote(binomial)
84+
reg_param = 1,
85+
family = "binomial"
8986
)
9087
)
9188

@@ -96,19 +93,18 @@ test_that('primary arguments', {
9693
list(
9794
x = quote(missing_arg()),
9895
y = quote(missing_arg()),
99-
family = "binomial",
10096
weights = quote(missing_arg()),
101-
alpha = varying()
97+
alpha = varying(),
98+
family = "binomial"
10299
)
103100
)
104101
expect_equal(mixture_v_spark$method$fit_args,
105102
list(
106103
x = quote(missing_arg()),
107-
elastic_net_param = varying(),
104+
formula = quote(missing_arg()),
108105
weight_col = quote(missing_arg()),
109-
features_col = quote(missing_arg()),
110-
label_col = quote(missing_arg()),
111-
family = quote(binomial)
106+
elastic_net_param = varying(),
107+
family = "binomial"
112108
)
113109
)
114110

@@ -119,9 +115,9 @@ test_that('engine arguments', {
119115
expect_equal(translate(glm_fam, engine = "glm")$method$fit_args,
120116
list(
121117
formula = quote(missing_arg()),
122-
family = quote(binomial(link = "probit")),
123118
data = quote(missing_arg()),
124-
weights = quote(missing_arg())
119+
weights = quote(missing_arg()),
120+
family = quote(binomial(link = "probit"))
125121
)
126122
)
127123

@@ -130,33 +126,32 @@ test_that('engine arguments', {
130126
list(
131127
x = quote(missing_arg()),
132128
y = quote(missing_arg()),
133-
family = "binomial",
134129
weights = quote(missing_arg()),
135-
nlambda = 10
130+
nlambda = 10,
131+
family = "binomial"
136132
)
137133
)
138134

139135
stan_samp <- logistic_reg(others = list(chains = 1, iter = 5))
140136
expect_equal(translate(stan_samp, engine = "stan")$method$fit_args,
141137
list(
142138
formula = quote(missing_arg()),
143-
family = quote(binomial),
144139
data = quote(missing_arg()),
145140
weights = quote(missing_arg()),
146141
chains = 1,
147-
iter = 5
142+
iter = 5,
143+
family = quote(binomial)
148144
)
149145
)
150146

151147
spark_iter <- logistic_reg(others = list(max_iter = 20))
152148
expect_equal(translate(spark_iter, engine = "spark")$method$fit_args,
153149
list(
154150
x = quote(missing_arg()),
155-
max_iter = 20,
151+
formula = quote(missing_arg()),
156152
weight_col = quote(missing_arg()),
157-
features_col = quote(missing_arg()),
158-
label_col = quote(missing_arg()),
159-
family = quote(binomial)
153+
max_iter = 20,
154+
family = "binomial"
160155
)
161156
)
162157

@@ -194,7 +189,6 @@ test_that('bad input', {
194189
expect_error(logistic_reg(mixture = -1))
195190
expect_error(translate(logistic_reg(), engine = "wat?"))
196191
expect_warning(translate(logistic_reg(), engine = NULL))
197-
expect_warning(translate(logistic_reg(others = list(ytest = 2)), engine = "glmnet"))
198192
expect_error(translate(logistic_reg(formula = y ~ x)))
199193
expect_warning(translate(logistic_reg(others = list(x = iris[,1:3], y = iris$Species)), engine = "glmnet"))
200194
expect_warning(translate(logistic_reg(others = list(formula = y ~ x)), engine = "glm"))
@@ -270,15 +264,14 @@ test_that('glm execution', {
270264
)
271265
expect_true(inherits(glm_form_catch, "try-error"))
272266

273-
# fails
274-
# glm_xy_catch <- fit(
275-
# lc_basic,
276-
# engine = "glm",
277-
# control = caught_ctrl,
278-
# x = lending_club[, num_pred],
279-
# y = lending_club$total_bal_il
280-
# )
281-
# expect_true(inherits(glm_xy_catch, "try-error"))
267+
glm_xy_catch <- fit(
268+
lc_basic,
269+
engine = "glm",
270+
control = caught_ctrl,
271+
x = lending_club[, num_pred],
272+
y = lending_club$total_bal_il
273+
)
274+
expect_true(inherits(glm_xy_catch, "try-error"))
282275

283276
glm_rec_catch <- fit(
284277
lc_basic,
@@ -293,17 +286,16 @@ test_that('glm execution', {
293286
test_that('glmnet execution', {
294287
skip_on_cran()
295288

296-
# fails because `glment` requires a matrix
297-
# expect_error(
298-
# fit(
299-
# lc_basic,
300-
# lc_form,
301-
# data = lending_club,
302-
# engine = "glmnet",
303-
# control = ctrl
304-
# ),
305-
# regexp = NA
306-
# )
289+
expect_error(
290+
fit(
291+
lc_basic,
292+
lc_form,
293+
data = lending_club,
294+
engine = "glmnet",
295+
control = ctrl
296+
),
297+
regexp = NA
298+
)
307299

308300
expect_error(
309301
fit(
@@ -316,7 +308,11 @@ test_that('glmnet execution', {
316308
regexp = NA
317309
)
318310

319-
# fails because `glment` requires a matrix
311+
# TODO: fails because the recipe tries to convert a data frame containing a
312+
# factor to a matrix (and trips an error checker). This is supposed to work
313+
# well with multivariate data when the model interface is a matrix but it
314+
# shouldn't automatically do that for a single column non-numeric data set.
315+
# One more coded exception
320316
# expect_error(
321317
# fit(
322318
# lc_basic,
@@ -433,15 +429,14 @@ test_that('stan_glm execution', {
433429
)
434430
expect_true(inherits(stan_form_catch, "try-error"))
435431

436-
# fails
437-
# stan_xy_catch <- fit(
438-
# lc_basic,
439-
# engine = "stan",
440-
# control = caught_ctrl,
441-
# x = lending_club[, num_pred],
442-
# y = lending_club$total_bal_il
443-
# )
444-
# expect_true(inherits(stan_xy_catch, "try-error"))
432+
stan_xy_catch <- fit(
433+
lc_basic,
434+
engine = "stan",
435+
control = caught_ctrl,
436+
x = lending_club[, num_pred],
437+
y = lending_club$total_bal_il
438+
)
439+
expect_true(inherits(stan_xy_catch, "try-error"))
445440

446441
stan_rec_catch <- fit(
447442
lc_basic,

0 commit comments

Comments
 (0)