Skip to content

Commit 3fb3694

Browse files
Test that all models work with as_parsed_model() (#163)
1 parent e5414d2 commit 3fb3694

21 files changed

+102
-178
lines changed

R/model-cubist.R

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,33 @@ parse_model.cubist <- function(model) {
6464
}
6565
)
6666
comm <- purrr::list_flatten(committees2)
67+
68+
if (model$committees == 1) {
69+
ommittee_id <- rep(1, length(comm))
70+
} else {
71+
model_print <- utils::capture.output(print(model))
72+
model_print <- model_print[grep(
73+
"Number of rules per committee",
74+
model_print
75+
)]
76+
model_print <- regmatches(
77+
model_print,
78+
m = gregexpr("[0-9]+", model_print)
79+
)[[
80+
1
81+
]]
82+
ommittee_id <- as.integer(model_print)
83+
ommittee_id <- rep(seq_along(ommittee_id), times = ommittee_id)
84+
}
85+
6786
pm <- list(
6887
general = list(
6988
model = "cubist",
7089
type = "tree",
71-
version = 2,
90+
version = 3,
7291
mode = "ifelse",
73-
divisor = model$committees
92+
n_committees = model$committees,
93+
ommittee_id = ommittee_id
7494
),
7595
trees = list(comm)
7696
)
@@ -80,28 +100,15 @@ parse_model.cubist <- function(model) {
80100
#' @export
81101
tidypredict_fit.cubist <- function(model) {
82102
parsedmodel <- parse_model(model)
103+
tidypredict_fit_cubist(parsedmodel)
104+
}
105+
106+
tidypredict_fit_cubist <- function(parsedmodel) {
83107
rules <- generate_tree_nodes(parsedmodel$trees[[1]], parsedmodel$general$mode)
84108
paths <- lapply(parsedmodel$trees[[1]], function(x) path_formulas(x$path))
85109

86-
n_committees <- model$committees
87-
88-
if (n_committees == 1) {
89-
ommittee_id <- rep(1, length(rules))
90-
} else {
91-
model_print <- utils::capture.output(print(model))
92-
model_print <- model_print[grep(
93-
"Number of rules per committee",
94-
model_print
95-
)]
96-
model_print <- regmatches(
97-
model_print,
98-
m = gregexpr("[0-9]+", model_print)
99-
)[[
100-
1
101-
]]
102-
ommittee_id <- as.integer(model_print)
103-
ommittee_id <- rep(seq_along(ommittee_id), times = ommittee_id)
104-
}
110+
n_committees <- parsedmodel$general$n_committees
111+
ommittee_id <- parsedmodel$general$ommittee_id
105112

106113
committees <- purrr::map2(
107114
split(rules, ommittee_id),

R/model-rf.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ parse_model.randomForest <- function(model) {
8686
#' @export
8787
tidypredict_fit.randomForest <- function(model) {
8888
parsedmodel <- parse_model(model)
89+
tidypredict_fit_randomForest(parsedmodel)
90+
}
91+
92+
tidypredict_fit_randomForest <- function(parsedmodel) {
8993
res <- generate_case_when_trees(parsedmodel)
9094
res <- reduce_addition(res)
9195
n_trees <- length(parsedmodel$trees)

R/predict-fit.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,16 @@ tidypredict_fit.pm_regression <- function(model) {
2121

2222
#' @export
2323
tidypredict_fit.pm_tree <- function(model) {
24-
generate_case_when_trees(model)
24+
if (model$general$model == "cubist") {
25+
return(tidypredict_fit_cubist(model))
26+
}
27+
if (model$general$model == "randomForest") {
28+
return(tidypredict_fit_randomForest(model))
29+
}
30+
31+
res <- generate_case_when_trees(model)
32+
33+
reduce_addition(res)
2534
}
2635

2736
#' @export

tests/testthat/_snaps/model-cubist.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,6 @@
55
Output
66
[1] "(37.2 + hp * -0.0318 + wt * -3.88 + (ifelse(disp > 95.099998, \n 14.89 + hp * -0.0406 + drat * 2.4, 0) + ifelse(disp <= 95.099998, \n 33.06, 0))/((disp > 95.099998) + (disp <= 95.099998)) + (37.26 + \n wt * -5.28))/3"
77

8-
# Model can be saved and re-loaded
9-
10-
Code
11-
tidypredict_fit(pm)
12-
Output
13-
(37.2 + hp * -0.0318 + wt * -3.88 + ifelse(disp > 95.099998,
14-
14.89 + hp * -0.0406 + drat * 2.4, 0) + ifelse(disp <= 95.099998,
15-
33.06, 0) + (37.26 + wt * -5.28))/3
16-
178
# formulas produces correct predictions
189

1910
Code

tests/testthat/_snaps/model-earth.md

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,6 @@
55
Output
66
[1] "20.534817535821 + (ifelse(disp < 145, 145 - disp, 0) * 0.148589866311) + \n (ifelse(disp > 145, disp - 145, 0) * -0.025012854678)"
77

8-
# Model can be saved and re-loaded
9-
10-
Code
11-
tidypredict_fit(pm)
12-
Output
13-
20.5348175 + (ifelse(disp < 145, 145 - disp, 0) * 0.1485899) +
14-
(ifelse(disp > 145, disp - 145, 0) * -0.0250129)
15-
168
# formulas produces correct predictions
179

1810
Code

tests/testthat/_snaps/model-glm.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,6 @@
55
Output
66
[1] "1.520331147866 + (wt * -0.372988616484) + (cyl * 0.013885491477)"
77

8-
# Model can be saved and re-loaded
9-
10-
Code
11-
tidypredict_fit(pm)
12-
Output
13-
1.5203311 + (wt * -0.3729886) + (cyl * 0.0138855)
14-
158
# formulas produces correct predictions
169

1710
Code

tests/testthat/_snaps/model-lm.md

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,6 @@
55
Output
66
[1] "1.520331147866 + (wt * -0.372988616484) + (cyl * 0.013885491477)"
77

8-
# Model can be saved and re-loaded
9-
10-
Code
11-
tidypredict_fit(pm)
12-
Output
13-
1.5203311 + (wt * -0.3729886) + (cyl * 0.0138855)
14-
158
# formulas produces correct predictions
169

1710
Code

tests/testthat/_snaps/model-partykit.md

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,6 @@
55
Output
66
[1] "case_when(cyl <= 4 ~ 26.6636363636364, cyl <= 6 & cyl > 4 ~ 19.7428571428571, \n cyl > 6 & cyl > 4 ~ 15.1)"
77

8-
# Model can be saved and re-loaded
9-
10-
Code
11-
tidypredict_fit(pm)
12-
Output
13-
[[1]]
14-
case_when(cyl <= 4 ~ 26.6636364, cyl <= 6 & cyl > 4 ~ 19.7428571,
15-
cyl > 6 & cyl > 4 ~ 15.1)
16-
17-
188
# formulas produces correct predictions
199

2010
Code

tests/testthat/_snaps/model-ranger.md

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,6 @@
55
Output
66
[1] "case_when(Petal.Length < 2.6 ~ \"setosa\", Sepal.Length < 6.25 & \n Petal.Length >= 2.6 ~ \"versicolor\", Sepal.Length >= 6.25 & \n Petal.Length >= 2.6 ~ \"virginica\") + case_when(Petal.Width < \n 0.75 ~ \"setosa\", Petal.Width < 1.75 & Petal.Width >= 0.75 ~ \n \"versicolor\", Petal.Width >= 1.75 & Petal.Width >= 0.75 ~ \n \"virginica\") + case_when(Petal.Length < 2.35 ~ \"setosa\", \n Petal.Length < 4.75 & Petal.Length >= 2.35 ~ \"versicolor\", \n Petal.Length >= 4.75 & Petal.Length >= 2.35 ~ \"virginica\")"
77

8-
# Model can be saved and re-loaded
9-
10-
Code
11-
tidypredict_fit(pm)
12-
Output
13-
[[1]]
14-
case_when(Petal.Length < 2.6 ~ "setosa", Sepal.Length < 6.25 &
15-
Petal.Length >= 2.6 ~ "versicolor", Sepal.Length >= 6.25 &
16-
Petal.Length >= 2.6 ~ "virginica")
17-
18-
[[2]]
19-
case_when(Petal.Width < 0.75 ~ "setosa", Petal.Width < 1.75 &
20-
Petal.Width >= 0.75 ~ "versicolor", Petal.Width >= 1.75 &
21-
Petal.Width >= 0.75 ~ "virginica")
22-
23-
[[3]]
24-
case_when(Petal.Length < 2.35 ~ "setosa", Petal.Length < 4.75 &
25-
Petal.Length >= 2.35 ~ "versicolor", Petal.Length >= 4.75 &
26-
Petal.Length >= 2.35 ~ "virginica")
27-
28-
298
# formulas produces correct predictions
309

3110
Code

tests/testthat/_snaps/model-rf.md

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -5,54 +5,6 @@
55
Output
66
[1] "(case_when(disp < 78.85 & wt < 2.3325 ~ 33.525, disp >= 78.85 & \n wt < 2.3325 ~ 27.425, disp >= 281 & wt >= 2.3325 ~ 14.1, \n drat < 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 24.4, \n drat >= 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 22.3666666666667, \n disp < 152.5 & wt < 3.3275 & cyl >= 5 & disp < 281 & wt >= \n 2.3325 ~ 19.7, disp >= 152.5 & wt < 3.3275 & cyl >= 5 & \n disp < 281 & wt >= 2.3325 ~ 21.1, disp < 196.3 & wt >= \n 3.3275 & cyl >= 5 & disp < 281 & wt >= 2.3325 ~ 18.92, \n disp >= 196.3 & wt >= 3.3275 & cyl >= 5 & disp < 281 & wt >= \n 2.3325 ~ 18.1) + case_when(hp < 80.5 & disp < 266.9 ~ \n 28.5333333333333, drat < 3.035 & disp >= 266.9 ~ 10.4, carb < \n 2.5 & drat >= 3.035 & disp >= 266.9 ~ 18.7, wt < 1.989 & \n hp < 118 & hp >= 80.5 & disp < 266.9 ~ 30.4, qsec < 18.6 & \n hp >= 118 & hp >= 80.5 & disp < 266.9 ~ 19.6, qsec >= 18.6 & \n hp >= 118 & hp >= 80.5 & disp < 266.9 ~ 17.8, drat >= 3.635 & \n carb >= 2.5 & drat >= 3.035 & disp >= 266.9 ~ 13.3, hp < \n 96 & wt >= 1.989 & hp < 118 & hp >= 80.5 & disp < 266.9 ~ \n 22.8, wt < 4.5625 & drat < 3.635 & carb >= 2.5 & drat >= \n 3.035 & disp >= 266.9 ~ 15.04, wt >= 4.5625 & drat < 3.635 & \n carb >= 2.5 & drat >= 3.035 & disp >= 266.9 ~ 14.7, qsec < \n 19.725 & hp >= 96 & wt >= 1.989 & hp < 118 & hp >= 80.5 & \n disp < 266.9 ~ 21.4, qsec >= 19.725 & hp >= 96 & wt >= 1.989 & \n hp < 118 & hp >= 80.5 & disp < 266.9 ~ 21.5) + case_when(wt < \n 3.4725 & drat < 3.75 ~ 20.8333333333333, qsec < 16.23 & wt >= \n 3.4725 & drat < 3.75 ~ 13.3, disp < 78.85 & disp < 130.55 & \n drat >= 3.75 ~ 32.2333333333333, disp >= 78.85 & disp < 130.55 & \n drat >= 3.75 ~ 27.75, cyl < 5 & disp >= 130.55 & drat >= \n 3.75 ~ 22.8, disp >= 456 & qsec >= 16.23 & wt >= 3.4725 & \n drat < 3.75 ~ 10.4, qsec < 17.66 & cyl >= 5 & disp >= 130.55 & \n drat >= 3.75 ~ 21, qsec >= 17.66 & cyl >= 5 & disp >= 130.55 & \n drat >= 3.75 ~ 19.2, qsec < 17.225 & disp < 456 & qsec >= \n 16.23 & wt >= 3.4725 & drat < 3.75 ~ 19.2, wt < 4.7075 & \n qsec >= 17.225 & disp < 456 & qsec >= 16.23 & wt >= 3.4725 & \n drat < 3.75 ~ 16.34, wt >= 4.7075 & qsec >= 17.225 & disp < \n 456 & qsec >= 16.23 & wt >= 3.4725 & drat < 3.75 ~ 14.7))/3L"
77

8-
# Model can be saved and re-loaded
9-
10-
Code
11-
tidypredict_fit(pm)
12-
Output
13-
[[1]]
14-
case_when(disp < 78.85 & wt < 2.3325 ~ 33.525, disp >= 78.85 &
15-
wt < 2.3325 ~ 27.425, disp >= 281 & wt >= 2.3325 ~ 14.1,
16-
drat < 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 24.4,
17-
drat >= 3.695 & cyl < 5 & disp < 281 & wt >= 2.3325 ~ 22.3666667,
18-
disp < 152.5 & wt < 3.3275 & cyl >= 5 & disp < 281 & wt >=
19-
2.3325 ~ 19.7, disp >= 152.5 & wt < 3.3275 & cyl >= 5 &
20-
disp < 281 & wt >= 2.3325 ~ 21.1, disp < 196.3 & wt >=
21-
3.3275 & cyl >= 5 & disp < 281 & wt >= 2.3325 ~ 18.92,
22-
disp >= 196.3 & wt >= 3.3275 & cyl >= 5 & disp < 281 & wt >=
23-
2.3325 ~ 18.1)
24-
25-
[[2]]
26-
case_when(hp < 80.5 & disp < 266.9 ~ 28.5333333, drat < 3.035 &
27-
disp >= 266.9 ~ 10.4, carb < 2.5 & drat >= 3.035 & disp >=
28-
266.9 ~ 18.7, wt < 1.989 & hp < 118 & hp >= 80.5 & disp <
29-
266.9 ~ 30.4, qsec < 18.6 & hp >= 118 & hp >= 80.5 & disp <
30-
266.9 ~ 19.6, qsec >= 18.6 & hp >= 118 & hp >= 80.5 & disp <
31-
266.9 ~ 17.8, drat >= 3.635 & carb >= 2.5 & drat >= 3.035 &
32-
disp >= 266.9 ~ 13.3, hp < 96 & wt >= 1.989 & hp < 118 &
33-
hp >= 80.5 & disp < 266.9 ~ 22.8, wt < 4.5625 & drat < 3.635 &
34-
carb >= 2.5 & drat >= 3.035 & disp >= 266.9 ~ 15.04, wt >=
35-
4.5625 & drat < 3.635 & carb >= 2.5 & drat >= 3.035 & disp >=
36-
266.9 ~ 14.7, qsec < 19.725 & hp >= 96 & wt >= 1.989 & hp <
37-
118 & hp >= 80.5 & disp < 266.9 ~ 21.4, qsec >= 19.725 &
38-
hp >= 96 & wt >= 1.989 & hp < 118 & hp >= 80.5 & disp < 266.9 ~
39-
21.5)
40-
41-
[[3]]
42-
case_when(wt < 3.4725 & drat < 3.75 ~ 20.8333333, qsec < 16.23 &
43-
wt >= 3.4725 & drat < 3.75 ~ 13.3, disp < 78.85 & disp <
44-
130.55 & drat >= 3.75 ~ 32.2333333, disp >= 78.85 & disp <
45-
130.55 & drat >= 3.75 ~ 27.75, cyl < 5 & disp >= 130.55 &
46-
drat >= 3.75 ~ 22.8, disp >= 456 & qsec >= 16.23 & wt >=
47-
3.4725 & drat < 3.75 ~ 10.4, qsec < 17.66 & cyl >= 5 & disp >=
48-
130.55 & drat >= 3.75 ~ 21, qsec >= 17.66 & cyl >= 5 & disp >=
49-
130.55 & drat >= 3.75 ~ 19.2, qsec < 17.225 & disp < 456 &
50-
qsec >= 16.23 & wt >= 3.4725 & drat < 3.75 ~ 19.2, wt < 4.7075 &
51-
qsec >= 17.225 & disp < 456 & qsec >= 16.23 & wt >= 3.4725 &
52-
drat < 3.75 ~ 16.34, wt >= 4.7075 & qsec >= 17.225 & disp <
53-
456 & qsec >= 16.23 & wt >= 3.4725 & drat < 3.75 ~ 14.7)
54-
55-
568
# formulas produces correct predictions
579

5810
Code

0 commit comments

Comments
 (0)