"""
Implementation of the `AbstractTrees.jl`-interface
(see: [AbstractTrees.jl](https://github.com/JuliaCollections/AbstractTrees.jl)).

The functions `children` and `printnode` make up the interface traits of `AbstractTrees.jl`
(see below for details).

The goal of this implementation is to wrap a `DecisionTree` in this abstract layer,
so that a plot recipe for visualization of the tree can be created that doesn't rely
on any implementation details of `DecisionTree.jl`. That opens the possibility to create
a plot recipe which can be used by a variety of tree-like models.

For a more detailed explanation of this concept have a look at the following article
in "Towards Data Science":
["If things are not ready to use"](https://towardsdatascience.com/part-iii-if-things-are-not-ready-to-use-59d2db378bec)
"""

"""
    InfoNode{S, T}
    InfoLeaf{T}

These types are introduced so that additional information currently not present in
a `DecisionTree`-structure -- namely the feature names and the class labels --
can be used for visualization. This additional information is stored in the attribute
`info` of these types. As `info` is a `NamedTuple`, it can be used to store arbitrary
information beyond the two items mentioned.

In analogy to the type definitions of `DecisionTree`, the generic type `S` is
the type of the feature values used within a node as a threshold for the splits
between its children and `T` is the type of the classes given (these might be ids or labels).
"""
struct InfoNode{S, T}
    node :: DecisionTree.Node{S, T}
    info :: NamedTuple
end

struct InfoLeaf{T}
    leaf :: DecisionTree.Leaf{T}
    info :: NamedTuple
end

"""
    wrap(node::DecisionTree.Node, info = NamedTuple())
    wrap(leaf::DecisionTree.Leaf, info = NamedTuple())

Add to each `node` (or `leaf`) the additional information `info`
and wrap both in an `InfoNode`/`InfoLeaf`.

Typically a `node` or a `leaf` is obtained by creating a decision tree using either
the native interface of `DecisionTree.jl` or via other interfaces which are available
for this package (like `MLJ` or `ScikitLearn`; see their docs for further details).
Using the function `build_tree` of the native interface returns such an object.

To use a decision tree `dc` (obtained this way) with the abstraction layer
provided by the `AbstractTrees`-interface implemented here
and optionally add feature names `feature_names` and/or class labels `class_labels`
(both: arrays of strings), use the following syntax:

1. `wdc = wrap(dc)`
2. `wdc = wrap(dc, (featurenames = feature_names, classlabels = class_labels))`
3. `wdc = wrap(dc, (featurenames = feature_names, ))`
4. `wdc = wrap(dc, (classlabels = class_labels, ))`

In the first case `dc` is simply wrapped; no information is added. No. 2 adds feature names
as well as class labels. In the last two cases only one of these pieces of information is
added (note the trailing comma; it is needed to make the single-entry argument a `NamedTuple`).
"""
wrap(node::DecisionTree.Node, info::NamedTuple = NamedTuple()) = InfoNode(node, info)
wrap(leaf::DecisionTree.Leaf, info::NamedTuple = NamedTuple()) = InfoLeaf(leaf, info)
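
# A minimal usage sketch of `wrap` (the feature matrix, labels and the names passed
# via `featurenames`/`classlabels` below are made-up example data, not part of the package):
#
#     using DecisionTree, AbstractTrees
#
#     features = [1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 8.0]
#     labels   = [1, 1, 2, 2]
#     model    = build_tree(labels, features)      # native interface
#
#     wmodel = wrap(model, (featurenames = ["feat1", "feat2"],
#                           classlabels  = ["class A", "class B"]))
#     AbstractTrees.print_tree(wmodel)
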
"""
    children(node::InfoNode)

Return the children of the given `node`.

In case of a `DecisionTree` there are always exactly two children, because
the model produces binary trees where all nodes have exactly one left and
one right child. `children` is used for tree traversal.

The additional information `info` is carried over from `node` to its children.
"""
AbstractTrees.children(node::InfoNode) = (
    wrap(node.node.left, node.info),
    wrap(node.node.right, node.info)
)
AbstractTrees.children(leaf::InfoLeaf) = ()
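
# Traversal sketch (assuming `wmodel` is a wrapped tree as in the example above):
#
#     left, right = AbstractTrees.children(wmodel)   # both carry `wmodel.info`
#     collect(AbstractTrees.Leaves(wmodel))          # all `InfoLeaf`s of the tree
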
"""
    printnode(io::IO, node::InfoNode)
    printnode(io::IO, leaf::InfoLeaf)

Write a printable representation of `node` or `leaf` to the output stream `io`.

If `node.info`/`leaf.info` has a field called
- `featurenames`, it is expected to contain an array of feature names corresponding
  to the feature ids used in the `DecisionTree`s nodes.
  They will be used for printing instead of the ids.
- `classlabels`, it is expected to contain an array of class labels corresponding
  to the class ids used in the `DecisionTree`s leaves.
  They will be used for printing instead of the ids.
  (Note: `DecisionTree`s created using `MLJ` use ids in their leaves;
  otherwise class labels are present.)

For a condition of the form `feature < value`, which gets printed by the `printnode`
variant for `InfoNode`, the left subtree is the 'yes-branch' and the right subtree
accordingly the 'no-branch'. `AbstractTrees.print_tree` outputs the left subtree first
and then, below it, the right subtree.
"""
function AbstractTrees.printnode(io::IO, node::InfoNode)
    if :featurenames ∈ keys(node.info)
        print(io, node.info.featurenames[node.node.featid], " < ", node.node.featval)
    else
        print(io, "Feature: ", node.node.featid, " < ", node.node.featval)
    end
end

function AbstractTrees.printnode(io::IO, leaf::InfoLeaf)
    dt_leaf = leaf.leaf
    matches     = findall(dt_leaf.values .== dt_leaf.majority)  # samples of the majority class
    match_count = length(matches)
    val_count   = length(dt_leaf.values)                        # all samples in the leaf
    if :classlabels ∈ keys(leaf.info)
        print(io, leaf.info.classlabels[dt_leaf.majority], " ($match_count/$val_count)")
    else
        print(io, "Class: ", dt_leaf.majority, " ($match_count/$val_count)")
    end
end
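
# Printing sketch: with `children` and `printnode` in place, `AbstractTrees.print_tree`
# renders the whole wrapped tree (assuming `wmodel` from the example above):
#
#     AbstractTrees.print_tree(wmodel)
#
# Each inner node prints as `<feature> < <threshold>`, each leaf as
# `<class> (<majority count>/<total count>)`.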