Commit 52007b1

Merge pull request #158 from roland-KA/abstract-tree
Add implementation of AbstractTrees-interface
2 parents 17cb46d + bbfa757 commit 52007b1

File tree

8 files changed: +234 additions, -62 deletions

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -2,3 +2,4 @@
 styleguide.txt
 makefile
 .DS_Store
+Manifest.toml

Manifest.toml

Lines changed: 0 additions & 61 deletions
This file was deleted.

Project.toml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ desc = "Julia implementation of Decision Tree (CART) and Random Forest algorithm
 version = "0.10.11"

 [deps]
+AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

README.md

Lines changed: 9 additions & 0 deletions

@@ -319,3 +319,12 @@ using JLD2
 @save "model_file.jld2" model
 ```
 Note that even though features and labels of type `Array{Any}` are supported, it is highly recommended that data be cast to explicit types (i.e. with `float.()`, `string.()`, etc.). This significantly improves model training and prediction execution times, and also drastically reduces the size of saved models.
+
+## Visualization
+A `DecisionTree` model can be visualized using the `print_tree` function of its native interface
+(for an example, see the section 'Classification Example' above).
+
+In addition, an abstraction layer using `AbstractTrees.jl` has been implemented with the intention of facilitating visualizations that don't rely on any implementation details of `DecisionTree`. For more information, have a look at the docs in `src/abstract_trees.jl` and the [`wrap`](@ref) function, which creates this layer for a `DecisionTree` model.
+
+Apart from this, `AbstractTrees.jl` brings its own implementation of `print_tree`.
+
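For quick reference, a minimal sketch of the workflow described in this README addition, assuming a model trained with the native interface (the data and the feature names below are placeholders, not part of this commit):

```julia
using DecisionTree
import AbstractTrees

# train a tree with the native interface (placeholder data)
features = rand(100, 2)
labels = rand(["a", "b"], 100)
model = build_tree(labels, features)

print_tree(model)                       # native visualization

# wrap the model to obtain the AbstractTrees.jl abstraction layer;
# the `featurenames` entry is optional (illustrative names here)
wrapped = wrap(model, (featurenames = ["ft1", "ft2"],))
AbstractTrees.print_tree(wrapped)       # AbstractTrees' own print_tree
```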

src/DecisionTree.jl

Lines changed: 7 additions & 0 deletions

@@ -7,6 +7,7 @@ using DelimitedFiles
 using LinearAlgebra
 using Random
 using Statistics
+import AbstractTrees

 export Leaf, Node, Ensemble, print_tree, depth, build_stump, build_tree,
        prune_tree, apply_tree, apply_tree_proba, nfoldCV_tree, build_forest,
@@ -22,6 +23,7 @@ export DecisionTreeClassifier, DecisionTreeRegressor, RandomForestClassifier,
 # `using ScikitLearnBase`.
        predict, predict_proba, fit!, get_classes

+export InfoNode, InfoLeaf, wrap

 ###########################
 ########## Types ##########
@@ -65,6 +67,7 @@ include("util.jl")
 include("classification/main.jl")
 include("regression/main.jl")
 include("scikitlearnAPI.jl")
+include("abstract_trees.jl")


 #############################
@@ -107,6 +110,10 @@ R-> Feature 7, Threshold 108.1408338577021
     L-> 2 : 2434/15287
     R-> 8 : 1227/3508
 ```
+
+To facilitate visualization of trees using third party packages, a `DecisionTree.Leaf` object or
+`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
+AbstractTrees.jl interface. See [`wrap`](@ref) for details.
 """
 function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
     if depth == indent
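As a minimal sketch of the wrapping step mentioned in this docstring addition (the hand-built stump mirrors the new test file; all values are illustrative):

```julia
using DecisionTree
import AbstractTrees

# hand-built stump: split on feature 1 at threshold 0.7
stump = Node(1, 0.7, Leaf(1, [1, 1, 2]), Leaf(2, [1, 2, 2]))

wrapped = wrap(stump)               # InfoNode with an empty `info` NamedTuple
AbstractTrees.children(wrapped)     # -> two wrapped leaves (InfoLeaf)
AbstractTrees.print_tree(wrapped)   # traversal via `children`/`printnode`
```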

src/abstract_trees.jl

Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
+"""
+Implementation of the `AbstractTrees.jl`-interface
+(see: [AbstractTrees.jl](https://github.com/JuliaCollections/AbstractTrees.jl)).
+
+The functions `children` and `printnode` make up the interface traits of `AbstractTrees.jl`
+(see below for details).
+
+The goal of this implementation is to wrap a `DecisionTree` in this abstract layer,
+so that a plot recipe for visualization of the tree can be created that doesn't rely
+on any implementation details of `DecisionTree.jl`. That opens the possibility to create
+a plot recipe which can be used by a variety of tree-like models.
+
+For a more detailed explanation of this concept have a look at the following article
+in "Towards Data Science":
+["If things are not ready to use"](https://towardsdatascience.com/part-iii-if-things-are-not-ready-to-use-59d2db378bec)
+"""
+
+
+"""
+    InfoNode{S, T}
+    InfoLeaf{T}
+
+These types are introduced so that additional information currently not present in
+a `DecisionTree`-structure -- namely the feature names and the class labels --
+can be used for visualization. This additional information is stored in the attribute
+`info` of these types. It is a `NamedTuple`, so it can be used to store arbitrary
+information beyond the two items mentioned.
+
+In analogy to the type definitions of `DecisionTree`, the generic type `S` is
+the type of the feature values used within a node as a threshold for the splits
+between its children and `T` is the type of the classes given (these might be ids or labels).
+"""
+struct InfoNode{S, T}
+    node :: DecisionTree.Node{S, T}
+    info :: NamedTuple
+end
+
+struct InfoLeaf{T}
+    leaf :: DecisionTree.Leaf{T}
+    info :: NamedTuple
+end
+
+"""
+    wrap(node::DecisionTree.Node, info = NamedTuple())
+    wrap(leaf::DecisionTree.Leaf, info = NamedTuple())
+
+Add to each `node` (or `leaf`) the additional information `info`
+and wrap both in an `InfoNode`/`InfoLeaf`.
+
+Typically a `node` or a `leaf` is obtained by creating a decision tree using either
+the native interface of `DecisionTree.jl` or via other interfaces which are available
+for this package (like `MLJ` or ScikitLearn; see their docs for further details).
+Using the function `build_tree` of the native interface returns such an object.
+
+To use a DecisionTree `dc` (obtained this way) with the abstraction layer
+provided by the `AbstractTrees`-interface implemented here
+and optionally add feature names `feature_names` and/or `class_labels`
+(both: arrays of strings) use the following syntax:
+
+1.  `wdc = wrap(dc)`
+2.  `wdc = wrap(dc, (featurenames = feature_names, classlabels = class_labels))`
+3.  `wdc = wrap(dc, (featurenames = feature_names, ))`
+4.  `wdc = wrap(dc, (classlabels = class_labels, ))`
+
+In the first case `dc` just gets wrapped, no information is added. No. 2 adds feature names
+as well as class labels. In the last two cases only one kind of information is added (note the
+trailing comma; it is needed to make the argument a `NamedTuple`).
+"""
+wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
+wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
+
+"""
+    children(node::InfoNode)
+
+Return the children of the given `node`.
+
+In case of a `DecisionTree` there are always exactly two children, because
+the model produces binary trees where all nodes have exactly one left and
+one right child. `children` is used for tree traversal.
+
+The additional information `info` is carried over from `node` to its children.
+"""
+AbstractTrees.children(node::InfoNode) = (
+    wrap(node.node.left, node.info),
+    wrap(node.node.right, node.info)
+)
+AbstractTrees.children(node::InfoLeaf) = ()
+
+"""
+    printnode(io::IO, node::InfoNode)
+    printnode(io::IO, leaf::InfoLeaf)
+
+Write a printable representation of `node` or `leaf` to output-stream `io`.
+
+If `node.info`/`leaf.info` has a field called
+- `featurenames`, it is expected to contain an array of feature names corresponding
+  to the feature ids used in the `DecisionTree`'s nodes.
+  They will be used for printing instead of the ids.
+- `classlabels`, it is expected to contain an array of class labels corresponding
+  to the class ids used in the `DecisionTree`'s leaves.
+  They will be used for printing instead of the ids.
+  (Note: DecisionTrees created using MLJ use ids in their leaves;
+  otherwise class labels are present.)
+
+For the condition of the form `feature < value` which gets printed in the `printnode`
+variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree
+accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first
+and then below it the right subtree.
+"""
+function AbstractTrees.printnode(io::IO, node::InfoNode)
+    if :featurenames ∈ keys(node.info)
+        print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval)
+    else
+        print(io, "Feature: ", node.node.featid, " < ", node.node.featval)
+    end
+end
+
+function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
+    dt_leaf = leaf.leaf
+    matches     = findall(dt_leaf.values .== dt_leaf.majority)
+    match_count = length(matches)
+    val_count   = length(dt_leaf.values)
+    if :classlabels ∈ keys(leaf.info)
+        print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
+    else
+        print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)")
+    end
+end
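To illustrate how `printnode` uses the optional `info` fields, a small sketch (the feature and class names are made up; `repr_tree` is the `AbstractTrees` helper also used in the new tests, and the commented output is approximate):

```julia
using DecisionTree
import AbstractTrees

tree = Node(1, 0.7, Leaf(1, [1, 1, 2]), Leaf(2, [1, 2, 2]))

# without `info`, ids are printed:
print(AbstractTrees.repr_tree(wrap(tree)))
# Feature: 1 < 0.7
# ├─ Class: 1 (2/3)
# └─ Class: 2 (2/3)

# with feature names and class labels (illustrative values):
wrapped = wrap(tree, (featurenames = ["width"], classlabels = ["setosa", "versicolor"]))
print(AbstractTrees.repr_tree(wrapped))
# width < 0.7
# ├─ setosa (2/3)
# └─ versicolor (2/3)
```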
test/miscellaneous/abstract_trees_test.jl

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+# Test `AbstractTrees`-interface
+
+@testset "abstract_trees_test.jl" begin
+
+# CAVEAT: These tests rely heavily on the texts generated in `printnode`.
+# After changes in `printnode` the following `*pattern`s might have to be adapted.
+
+### Some content-checking helpers
+# if no feature names or class labels are given, the following keywords must be present
+featid_pattern      = "Feature: "               # feature ids are prepended by this text
+classid_pattern     = "Class: "                 # `Leaf.majority` is prepended by this text
+# if feature names and class labels are given, they can be identified within the tree using these patterns
+fname_pattern(fname)   = fname * " <"           # feature names are followed by " <"
+clabel_pattern(clabel) = "─ " * clabel * " ("   # class labels are embedded in "─ " and " ("
+
+# do all elements of `pool` occur in `str_tree` in the form defined by `fname_pattern`/`clabel_pattern`?
+check_occurence(str_tree, pool, pattern) = count(map(elem -> occursin(pattern(elem), str_tree), pool)) == length(pool)
+
+@info("Test base functionality")
+l1 = Leaf(1, [1,1,2])
+l2 = Leaf(2, [1,2,2])
+l3 = Leaf(3, [3,3,1])
+n2 = Node(2, 0.5, l2, l3)
+n1 = Node(1, 0.7, l1, n2)
+feature_names = ["firstFt", "secondFt"]
+class_labels = ["a", "b", "c"]
+
+infotree1 = wrap(n1, (featurenames = feature_names, classlabels = class_labels))
+infotree2 = wrap(n1, (featurenames = feature_names,))
+infotree3 = wrap(n1, (classlabels = class_labels,))
+infotree4 = wrap(n1, (x = feature_names, y = class_labels))
+infotree5 = wrap(n1)
+
+@info(" -- Tree with feature names and class labels")
+AbstractTrees.print_tree(infotree1)
+rep1 = AbstractTrees.repr_tree(infotree1)
+@test check_occurence(rep1, feature_names, fname_pattern)
+@test check_occurence(rep1, class_labels, clabel_pattern)
+
+@info(" -- Tree with feature names")
+AbstractTrees.print_tree(infotree2)
+rep2 = AbstractTrees.repr_tree(infotree2)
+@test check_occurence(rep2, feature_names, fname_pattern)
+@test occursin(classid_pattern, rep2)
+
+@info(" -- Tree with class labels")
+AbstractTrees.print_tree(infotree3)
+rep3 = AbstractTrees.repr_tree(infotree3)
+@test occursin(featid_pattern, rep3)
+@test check_occurence(rep3, class_labels, clabel_pattern)
+
+@info(" -- Tree with ids only (nonsense parameters)")
+AbstractTrees.print_tree(infotree4)
+rep4 = AbstractTrees.repr_tree(infotree4)
+@test occursin(featid_pattern, rep4)
+@test occursin(classid_pattern, rep4)
+
+@info(" -- Tree with ids only")
+AbstractTrees.print_tree(infotree5)
+rep5 = AbstractTrees.repr_tree(infotree5)
+@test occursin(featid_pattern, rep5)
+@test occursin(classid_pattern, rep5)
+
+@info("Test `children` with 'adult' decision tree")
+@info(" -- Preparing test data")
+features, labels = load_data("adult")
+feature_names_adult = ["age", "workclass", "fnlwgt", "education", "education-num", "marital-status", "occupation",
+    "relationship", "race", "sex", "capital-gain", "capital-loss", "hours-per-week", "native-country"]
+model = build_tree(labels, features)
+wrapped_tree = wrap(model, (featurenames = feature_names_adult,))
+
+@info(" -- Test `children`")
+function traverse_tree(node::InfoNode)
+    l, r = AbstractTrees.children(node)
+    @test l.info == node.info
+    @test r.info == node.info
+    traverse_tree(l)
+    traverse_tree(r)
+end
+
+traverse_tree(leaf::InfoLeaf) = nothing
+
+traverse_tree(wrapped_tree)
+end

test/runtests.jl

Lines changed: 4 additions & 1 deletion

@@ -3,6 +3,7 @@ using DelimitedFiles
 using Random
 using ScikitLearnBase
 using Statistics
+import AbstractTrees
 using Test

 println("Julia version: ", VERSION)
@@ -33,8 +34,10 @@ regression = [
 ]

 miscellaneous = [
-    "miscellaneous/convert.jl"
+    "miscellaneous/convert.jl",
+    "miscellaneous/abstract_trees_test.jl"
     # "miscellaneous/parallel.jl"
+
 ]

 test_suites = [
