Skip to content

Commit 0ec108d

Browse files
authored
changes for issues #120 and #115 (#159)
For classification problems, an error is thrown if the outcome is not a factor: ``` Error: For classification models, the outcome should be a factor. ``` There are a lot of travis related changes too: * To get travis to run tests, the modeling packages have to be formal dependencies so a bunch were added to Suggests. This _may_ be temporary; I may decide to remove these for the version sent to CRAN. `rstanarm` was excluded because compiling it (and its dependencies) exceeded the time allowed by travis. * A variety of changes were made to the tests related to r-devel. It looks like any function directly accessed in the tests now needs to be a formal dependency (not the case before). There is still some weirdness about this though and I've attributed it to being _devel_. Hopefully this will get smoothed out.
1 parent 52ddb78 commit 0ec108d

19 files changed

+284
-173
lines changed

.Rbuildignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
^\.Rproj\.user$
88
^.travis.yml$
99
^R/README\.md$
10+
derby.log
11+
^logs$
12+
^tests/testthat/logs$

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,5 @@
55
tests/testthat/derby.log
66
tests/testthat/logs/
77
*.history
8+
derby.log
9+
logs/*

.travis.yml

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,67 +8,76 @@ sudo: true
88
warnings_are_errors: false
99

1010
r:
11-
- 3.1
12-
- 3.2
13-
- oldrel
14-
- release
15-
- devel
11+
- 3.1
12+
- 3.2
13+
- oldrel
14+
- release
15+
- devel
1616

17-
env:
18-
global:
19-
- KERAS_BACKEND="tensorflow"
20-
- MAKEFLAGS="-j 2"
2117

22-
# until we troubleshoot these issues
2318
matrix:
2419
allow_failures:
2520
- r: 3.1
2621
- r: 3.2
2722

2823
r_binary_packages:
29-
- rstan
30-
- rstanarm
31-
- RCurl
32-
- dplyr
33-
- glue
34-
- magrittr
35-
- stringi
36-
- stringr
37-
- munsell
38-
- rlang
39-
- reshape2
40-
- scales
41-
- tibble
42-
- ggplot2
43-
- StanHeaders
44-
- Rcpp
45-
- RcppEigen
46-
- BH
47-
- glmnet
48-
- earth
49-
- sparklyr
50-
- flexsurv
51-
- ranger
52-
- randomforest
53-
- xgboost
54-
- C50
24+
- RCurl
25+
- dplyr
26+
- glue
27+
- magrittr
28+
- stringi
29+
- stringr
30+
- munsell
31+
- rlang
32+
- reshape2
33+
- scales
34+
- tibble
35+
- ggplot2
36+
- Rcpp
37+
- RcppEigen
38+
- BH
39+
- glmnet
40+
- earth
41+
- sparklyr
42+
- flexsurv
43+
- ranger
44+
- randomforest
45+
- xgboost
46+
- C50
47+
5548

5649
cache:
5750
packages: true
5851
directories:
5952
- $HOME/.keras
6053
- $HOME/.cache/pip
6154

55+
env:
56+
global:
57+
- KERAS_BACKEND="tensorflow"
58+
- MAKEFLAGS="-j 2"
59+
60+
addons:
61+
apt:
62+
sources:
63+
- ubuntu-toolchain-r-test
64+
packages:
65+
g++-6
6266

6367
before_script:
6468
- python -m pip install --upgrade --ignore-installed --user travis pip setuptools wheel virtualenv
6569
- python -m pip install --upgrade --ignore-installed --user travis keras h5py pyyaml requests Pillow scipy theano
6670
- R -e 'tensorflow::install_tensorflow()'
6771

72+
6873
before_install:
6974
- sudo apt-get -y install libnlopt-dev
7075
- sudo apt-get update
7176
- sudo apt-get -y install python3
77+
- mkdir -p ~/.R && echo "CXX14=g++-6" > ~/.R/Makevars
78+
- echo "CXX14FLAGS += -fPIC" >> ~/.R/Makevars
79+
7280

7381
after_success:
7482
- Rscript -e 'covr::codecov()'
83+

DESCRIPTION

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,14 @@ Suggests:
3838
keras,
3939
xgboost,
4040
covr,
41-
sparklyr
41+
C50,
42+
sparklyr,
43+
earth,
44+
glmnet,
45+
kernlab,
46+
kknn,
47+
randomForest,
48+
ranger,
49+
rpart,
50+
MASS,
51+
nlme

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@ that are actually varying).
1414

1515
* `fit_control()` now returns an S3 method.
1616

17+
* For classification models, an error occurs if the outcome data are not encoded as factors (#115).
18+
1719
* The prediction modules (e.g. `predict_class`, `predict_numeric`, etc) were de-exported. These were internal functions that were not intended to be used by users, but users were using them.
1820

21+
1922
## Bug Fixes
2023

2124
* `varying_args()` now uses the version from the `generics` package. This means

R/fit_helpers.R

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@
66
#' @importFrom stats model.frame model.response terms as.formula model.matrix
77
form_form <-
88
function(object, control, env, ...) {
9-
opts <- quos(...)
109

11-
if (object$mode != "regression") {
12-
y_levels <- levels_from_formula( # prob rewrite this as simple subset/levels
13-
env$formula,
14-
env$data
15-
)
10+
if (object$mode == "classification") {
11+
# prob rewrite this as simple subset/levels
12+
y_levels <- levels_from_formula(env$formula, env$data)
13+
if (!inherits(env$data, "tbl_spark") && is.null(y_levels))
14+
stop("For classification models, the outcome should be a factor.",
15+
call. = FALSE)
1616
} else {
1717
y_levels <- NULL
1818
}
1919

2020
object <- check_mode(object, y_levels)
2121

2222
# if descriptors are needed, update descr_env with the calculated values
23-
if(requires_descrs(object)) {
23+
if (requires_descrs(object)) {
2424
data_stats <- get_descr_form(env$formula, env$data)
2525
scoped_descrs(data_stats)
2626
}
@@ -71,8 +71,14 @@ xy_xy <- function(object, env, control, target = "none", ...) {
7171

7272
object <- check_mode(object, levels(env$y))
7373

74+
if (object$mode == "classification") {
75+
if (is.null(levels(env$y)))
76+
stop("For classification models, the outcome should be a factor.",
77+
call. = FALSE)
78+
}
79+
7480
# if descriptors are needed, update descr_env with the calculated values
75-
if(requires_descrs(object)) {
81+
if (requires_descrs(object)) {
7682
data_stats <- get_descr_form(env$formula, env$data)
7783
scoped_descrs(data_stats)
7884
}
@@ -125,13 +131,12 @@ form_xy <- function(object, control, env,
125131
env$x <- data_obj$x
126132
env$y <- data_obj$y
127133

128-
res <- list(
129-
lvl = levels_from_formula(
130-
env$formula,
131-
env$data
132-
),
133-
spec = object
134-
)
134+
res <- list(lvl = levels_from_formula(env$formula, env$data), spec = object)
135+
if (object$mode == "classification") {
136+
if (is.null(res$lvl))
137+
stop("For classification models, the outcome should be a factor.",
138+
call. = FALSE)
139+
}
135140

136141
res <- xy_xy(
137142
object = object,
@@ -148,6 +153,13 @@ form_xy <- function(object, control, env,
148153
}
149154

150155
xy_form <- function(object, env, control, ...) {
156+
157+
if (object$mode == "classification") {
158+
if (is.null(levels(env$y)))
159+
stop("For classification models, the outcome should be a factor.",
160+
call. = FALSE)
161+
}
162+
151163
data_obj <-
152164
convert_xy_to_form_fit(
153165
x = env$x,

R/multinom_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ check_args.multinom_reg <- function(object) {
168168

169169
args <- lapply(object$args, rlang::eval_tidy)
170170

171-
if (is.numeric(args$penalty) && args$penalty < 0)
171+
if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
172172
stop("The amount of regularization should be >= 0", call. = FALSE)
173173
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
174174
stop("The mixture proportion should be within [0,1]", call. = FALSE)

tests/testthat/test_boost_tree_C50.R

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
library(testthat)
22
library(parsnip)
33
library(tibble)
4+
library(dplyr)
45

56
# ------------------------------------------------------------------------------
67

78
context("boosted tree execution with C5.0")
89

910
data("lending_club")
1011
lending_club <- head(lending_club, 200)
12+
lending_club_fail <-
13+
lending_club %>%
14+
mutate(bad = Inf, miss = NA)
1115
num_pred <- c("funded_amnt", "annual_inc", "num_il_tl")
1216
lc_basic <-
1317
boost_tree(mode = "classification") %>%
@@ -41,6 +45,8 @@ test_that('C5.0 execution', {
4145
),
4246
regexp = NA
4347
)
48+
49+
# outcome is not a factor:
4450
expect_error(
4551
res <- fit(
4652
lc_basic,
@@ -51,19 +57,21 @@ test_that('C5.0 execution', {
5157
)
5258
)
5359

60+
# Model fails
5461
C5.0_form_catch <- fit(
5562
lc_basic,
56-
funded_amnt ~ term,
57-
data = lending_club,
63+
Class ~ miss,
64+
data = lending_club_fail,
5865
control = caught_ctrl
5966
)
6067
expect_true(inherits(C5.0_form_catch$fit, "try-error"))
6168

69+
# Model fails
6270
C5.0_xy_catch <- fit_xy(
6371
lc_basic,
6472
control = caught_ctrl,
65-
x = lending_club[, num_pred],
66-
y = lending_club$total_bal_il
73+
x = lending_club_fail[, "miss"],
74+
y = lending_club_fail$Class
6775
)
6876
expect_true(inherits(C5.0_xy_catch$fit, "try-error"))
6977
})
@@ -108,11 +116,12 @@ test_that('C5.0 probabilities', {
108116
test_that('submodel prediction', {
109117

110118
skip_if_not_installed("C50")
119+
library(C50)
111120

112121
vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges")
113122
class_fit <-
114123
boost_tree(trees = 20, mode = "classification") %>%
115-
set_engine("C5.0", control = C50::C5.0Control(earlyStopping = FALSE)) %>%
124+
set_engine("C5.0", control = C5.0Control(earlyStopping = FALSE)) %>%
116125
fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)])
117126

118127
pred_class <- predict(class_fit$fit, wa_churn[1:4, vars], trials = 4, type = "prob")

tests/testthat/test_linear_reg_stan.R

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,14 +102,21 @@ test_that('stan intervals', {
102102
type = "pred_int",
103103
level = 0.93)
104104

105-
prediction_stan <-
106-
predictive_interval(res_xy$fit, newdata = iris[1:5, ], seed = 13,
107-
prob = 0.93)
108-
109-
stan_post <- posterior_linpred(res_xy$fit, newdata = iris[1:5, ],
110-
seed = 13)
111-
stan_lower <- apply(stan_post, 2, quantile, prob = 0.035)
112-
stan_upper <- apply(stan_post, 2, quantile, prob = 0.965)
105+
# prediction_stan <-
106+
# predictive_interval(res_xy$fit, newdata = iris[1:5, ], seed = 13,
107+
# prob = 0.93)
108+
#
109+
# stan_post <- posterior_linpred(res_xy$fit, newdata = iris[1:5, ],
110+
# seed = 13)
111+
# stan_lower <- apply(stan_post, 2, quantile, prob = 0.035)
112+
# stan_upper <- apply(stan_post, 2, quantile, prob = 0.965)
113+
114+
stan_lower <- c(`1` = 4.93164991101342, `2` = 4.60197941230393,
115+
`3` = 4.6671442757811, `4` = 4.74402724639963,
116+
`5` = 4.99248110476701)
117+
stan_upper <- c(`1` = 5.1002837047058, `2` = 4.77617561853506,
118+
`3` = 4.83183673602725, `4` = 4.90844811805409,
119+
`5` = 5.16979395659009)
113120

114121
expect_equivalent(confidence_parsnip$.pred_lower, stan_lower)
115122
expect_equivalent(confidence_parsnip$.pred_upper, stan_upper)

tests/testthat/test_logistic_reg.R

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -244,23 +244,24 @@ test_that('glm execution', {
244244
)
245245
)
246246

247-
# passes interactively but not on R CMD check
248-
# glm_form_catch <- fit(
249-
# lc_basic,
250-
# funded_amnt ~ term,
251-
# data = lending_club,
252-
#
253-
# control = caught_ctrl
254-
# )
255-
# expect_true(inherits(glm_form_catch$fit, "try-error"))
247+
# wrong outcome type
248+
expect_error(
249+
glm_form_catch <- fit(
250+
lc_basic,
251+
funded_amnt ~ term,
252+
data = lending_club,
253+
control = caught_ctrl
254+
)
255+
)
256256

257-
glm_xy_catch <- fit_xy(
258-
lc_basic,
259-
control = caught_ctrl,
260-
x = lending_club[, num_pred],
261-
y = lending_club$total_bal_il
257+
expect_error(
258+
glm_xy_catch <- fit_xy(
259+
lc_basic,
260+
control = caught_ctrl,
261+
x = lending_club[, num_pred],
262+
y = lending_club$total_bal_il
263+
)
262264
)
263-
expect_true(inherits(glm_xy_catch$fit, "try-error"))
264265
})
265266

266267
test_that('glm prediction', {

0 commit comments

Comments
 (0)