|
1 | 1 | __precompile__() |
2 | 2 |
|
3 | | -module DecisionTree |
| 3 | +module DecisionTree |
4 | 4 |
|
5 | 5 | import Base: length, show, convert, promote_rule, zero |
6 | 6 | using DelimitedFiles |
@@ -80,55 +80,61 @@ length(ensemble::Ensemble) = length(ensemble.trees) |
"""
    depth(tree)

Return the depth of `tree`: a `Leaf` has depth zero, and each internal
`Node` contributes one level above its deeper subtree.
"""
depth(::Leaf) = 0
function depth(node::Node)
    return 1 + max(depth(node.left), depth(node.right))
end
82 | 82 |
|
# Leaf printer: a leaf is rendered as "majority : n_matches/n_total".
# `depth` and `indent` are accepted (and ignored) so leaves are drop-in
# recursion targets for the Node method below; likewise `sigdigits` and
# `feature_names` are accepted so keyword arguments can be forwarded
# uniformly down the tree.
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
    # count(==(x), v) avoids allocating the temporary BitArray that
    # count(v .== x) would create.
    n_matches = count(==(leaf.majority), leaf.values)
    ratio = string(n_matches, "/", length(leaf.values))
    println(io, "$(leaf.majority) : $(ratio)")
end

# Convenience method: print a leaf to stdout.
function print_tree(leaf::Leaf, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
    return print_tree(stdout, leaf, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
end


"""
    print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)

Print a textual visualization of the specified `tree` to `io`
(`stdout` if `io` is omitted). For example, if for some input pattern
the value of "Feature 3" is "-30" and the value of "Feature 2" is
"100", then, according to the sample output below, the majority class
prediction is 7. Moreover, one can see that of the 10555 training
samples that terminate at the same leaf as this input data, 2493 of
these predict the majority class, leading to a probabilistic
prediction for class 7 of `2493/10555`. Ratios for non-majority
classes are not shown.

# Keywords

- `depth`: stop descending once `depth` levels have been printed
  (`-1`, the default, prints the whole tree).
- `sigdigits=4`: round split thresholds to this many significant
  digits for display.
- `feature_names`: optional vector of feature names; when given, each
  split shows the name alongside the feature index.

# Example output:
```
Feature 3 < -28.15 ?
├─ Feature 2 < -161.0 ?
    ├─ 5 : 842/3650
    └─ 7 : 2493/10555
└─ Feature 7 < 108.1 ?
    ├─ 2 : 2434/15287
    └─ 8 : 1227/3508
```

To facilitate visualisation of trees using third party packages, a `DecisionTree.Leaf` object or
`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
AbstractTrees.jl interface. See [`wrap`](@ref) for details.
"""
function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
    # Depth cutoff: once we have printed `depth` levels, stop.
    if depth == indent
        println(io)
        return
    end
    # Round the threshold only for display; the tree itself is untouched.
    featval = round(tree.featval; sigdigits=sigdigits)
    if feature_names === nothing
        println(io, "Feature $(tree.featid) < $featval ?")
    else
        println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?")
    end
    # Forward BOTH keywords on recursion: previously `sigdigits` was
    # dropped here, so only the root threshold honored a caller's
    # `sigdigits` and deeper nodes silently reverted to the default.
    print(io, "    " ^ indent * "├─ ")
    print_tree(io, tree.left, depth, indent + 1; sigdigits=sigdigits, feature_names=feature_names)
    print(io, "    " ^ indent * "└─ ")
    print_tree(io, tree.right, depth, indent + 1; sigdigits=sigdigits, feature_names=feature_names)
end

# Convenience method: print a tree to stdout.
function print_tree(tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
    return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
end
133 | 139 |
|
134 | 140 | function show(io::IO, leaf::Leaf) |
|
0 commit comments