Skip to content

Commit e5414d2

Browse files
Don't multiple by 1 or 0 in tree results (#162)
1 parent e8d30c6 commit e5414d2

File tree

3 files changed

+56
-0
lines changed

3 files changed

+56
-0
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
- linear models such as `lm()` and `glm()` now work with interactions created with `*` and `:`. (#74)
2222

23+
- Cubist rules will return simplified rules whenever possible to avoid multiplying by 0 and 1. (#152)
24+
2325
# tidypredict 0.5.1
2426

2527
- Exported a number of internal functions to be used in {orbital} package

R/tree.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,24 @@ generate_tree_node <- function(node, calc_mode = "") {
110110
prediction,
111111
~ {
112112
if (.x$is_intercept) {
113+
if (.x$val == 0) {
114+
return(NULL)
115+
}
113116
return(expr(!!.x$val))
114117
} else if (.x$op == "multiply") {
118+
if (.x$val == 0) {
119+
return(NULL)
120+
}
121+
122+
if (.x$val == 1) {
123+
return(expr(!!as.name(.x$col)))
124+
}
125+
115126
return(expr_multiplication(as.name(.x$col), .x$val))
116127
}
117128
}
118129
)
130+
pl <- purrr::discard(pl, is.null)
119131
pl <- reduce_addition(pl)
120132
} else {
121133
if (is.list(prediction) && prediction[[1]]$is_intercept) {

tests/testthat/test-tree.R

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,48 @@ test_that("generate_tree_node() avoids ifelse if path is always TRUE (#143)", {
260260
)
261261
})
262262

263+
test_that("generate_tree_node() avoids multipliying with 0 and 1 (#152)", {
264+
node <- list(
265+
path = list(
266+
list(type = "conditional", col = "disp", val = 100, op = "more")
267+
),
268+
prediction = list(
269+
list(col = "(Intercept)", val = 0, op = "none", is_intercept = 1),
270+
list(col = "hp", val = 4, op = "multiply", is_intercept = 0),
271+
list(col = "drat", val = 2, op = "multiply", is_intercept = 0)
272+
)
273+
)
274+
275+
expect_identical(
276+
generate_tree_node(node, calc_mode = "ifelse"),
277+
quote(ifelse(disp > 100, hp * 4 + drat * 2, 0))
278+
)
279+
expect_identical(
280+
generate_tree_node(node, calc_mode = ""),
281+
quote(disp > 100 ~ hp * 4 + drat * 2)
282+
)
283+
284+
node <- list(
285+
path = list(
286+
list(type = "conditional", col = "disp", val = 100, op = "more")
287+
),
288+
prediction = list(
289+
list(col = "(Intercept)", val = 14, op = "none", is_intercept = 1),
290+
list(col = "hp", val = 1, op = "multiply", is_intercept = 0),
291+
list(col = "drat", val = 0, op = "multiply", is_intercept = 0)
292+
)
293+
)
294+
295+
expect_identical(
296+
generate_tree_node(node, calc_mode = "ifelse"),
297+
quote(ifelse(disp > 100, 14 + hp, 0))
298+
)
299+
expect_identical(
300+
generate_tree_node(node, calc_mode = ""),
301+
quote(disp > 100 ~ 14 + hp)
302+
)
303+
})
304+
263305
test_that("path_formulas() works", {
264306
expect_identical(
265307
path_formulas(

0 commit comments

Comments
 (0)