Skip to content

Commit a3398bf

Browse files
authored
Merge pull request #173 from rikhuijzer/rh/tree-digits
Round digits in `print_tree`
2 parents 63cb26a + 3ba1651 commit a3398bf

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed

src/DecisionTree.jl

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,18 @@ length(ensemble::Ensemble) = length(ensemble.trees)
8080
depth(leaf::Leaf) = 0
8181
depth(tree::Node) = 1 + max(depth(tree.left), depth(tree.right))
8282

83+
function print_tree(io::IO, leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
84+
n_matches = count(leaf.values .== leaf.majority)
85+
ratio = string(n_matches, "/", length(leaf.values))
86+
println(io, "$(leaf.majority) : $(ratio)")
87+
end
8388
function print_tree(leaf::Leaf, depth=-1, indent=0; feature_names=nothing)
84-
matches = findall(leaf.values .== leaf.majority)
85-
ratio = string(length(matches)) * "/" * string(length(leaf.values))
86-
println("$(leaf.majority) : $(ratio)")
89+
return print_tree(stdout, leaf, depth, indent; feature_names=feature_names)
8790
end
8891

92+
8993
"""
90-
print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
94+
print_tree([io::IO,] tree::Node, depth=-1, indent=0; sigdigits=4, feature_names=nothing)
9195
9296
Print a textual visualization of the specified `tree`. For example, if
9397
for some input pattern the value of "Feature 3" is "-30" and the value
@@ -100,7 +104,7 @@ non-majority classes are not shown.
100104
101105
# Example output:
102106
```
103-
Feature 3 < -28.15
107+
Feature 3 < -28.15 ?
104108
├─ Feature 2 < -161.0 ?
105109
├─ 5 : 842/3650
106110
└─ 7 : 2493/10555
@@ -113,20 +117,24 @@ To facilitate visualisation of trees using third party packages, a `DecisionTree
113117
`DecisionTree.Node` object can be wrapped to obtain a tree structure implementing the
114118
AbstractTrees.jl interface. See [`wrap`](@ref)` for details.
115119
"""
116-
function print_tree(tree::Node, depth=-1, indent=0; feature_names=nothing)
120+
function print_tree(io::IO, tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
117121
if depth == indent
118-
println()
122+
println(io)
119123
return
120124
end
125+
featval = round(tree.featval; sigdigits=sigdigits)
121126
if feature_names === nothing
122-
println("Feature $(tree.featid) < $(tree.featval)")
127+
println(io, "Feature $(tree.featid) < $featval ?")
123128
else
124-
println("Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $(tree.featval)")
129+
println(io, "Feature $(tree.featid): \"$(feature_names[tree.featid])\" < $featval ?")
125130
end
126-
print(" " ^ indent * "├─ ")
127-
print_tree(tree.left, depth, indent + 1; feature_names = feature_names)
128-
print(" " ^ indent * "└─ ")
129-
print_tree(tree.right, depth, indent + 1; feature_names = feature_names)
131+
print(io, " " ^ indent * "├─ ")
132+
print_tree(io, tree.left, depth, indent + 1; feature_names=feature_names)
133+
print(io, " " ^ indent * "└─ ")
134+
print_tree(io, tree.right, depth, indent + 1; feature_names=feature_names)
135+
end
136+
function print_tree(tree::Node, depth=-1, indent=0; sigdigits=2, feature_names=nothing)
137+
return print_tree(stdout, tree, depth, indent; sigdigits=sigdigits, feature_names=feature_names)
130138
end
131139

132140
function show(io::IO, leaf::Leaf)

src/scikitlearnAPI.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,5 +386,7 @@ length(dt::DecisionTreeClassifier) = length(dt.root)
386386
length(dt::DecisionTreeRegressor) = length(dt.root)
387387

388388
print_tree(dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
389-
print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
390-
print_tree(n::Nothing, depth=-1; kwargs...) = show(n)
389+
print_tree(io::IO, dt::DecisionTreeClassifier, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...)
390+
print_tree(dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(dt.root, depth; kwargs...)
391+
print_tree(io::IO, dt::DecisionTreeRegressor, depth=-1; kwargs...) = print_tree(io, dt.root, depth; kwargs...)
392+
print_tree(n::Nothing, depth=-1; kwargs...) = show(n)

test/classification/random.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
Random.seed!(16)
55

6-
n,m = 10^3, 5;
6+
n, m = 10^3, 5;
77
features = rand(n,m);
88
weights = rand(-1:1,m);
99
labels = round.(Int, features * weights);
@@ -15,7 +15,19 @@ preds = apply_tree(model, round.(Int, features))
1515
max_depth = 3
1616
model = build_tree(labels, features, 0, max_depth)
1717
@test depth(model) == max_depth
18-
print_tree(model, 3)
18+
19+
io = IOBuffer()
20+
print_tree(io, model, 3)
21+
text = String(take!(io))
22+
println()
23+
print(text)
24+
println()
25+
26+
# Read the regex as: many not arrow left followed by an arrow left, a space, some numbers and
27+
# a dot and a space and question mark.
28+
rx = r"[^<]*< [0-9\.]* ?"
29+
matches = eachmatch(rx, text)
30+
@test !isempty(matches)
1931

2032
model = build_tree(labels, features)
2133
preds = apply_tree(model, features)

0 commit comments

Comments
 (0)