1+ /*
2+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+ * or more contributor license agreements. Licensed under the Elastic License;
4+ * you may not use this file except in compliance with the Elastic License.
5+ */
6+
7+ package org .elasticsearch .xpack .ml .inference .tree ;
8+
9+ import org .elasticsearch .test .ESTestCase ;
10+
11+ import java .util .ArrayList ;
12+ import java .util .Arrays ;
13+ import java .util .List ;
14+
15+ import static org .hamcrest .Matchers .hasSize ;
16+
17+ public class TreeTests extends ESTestCase {
18+
19+ public static Tree buildRandomTree (int numFeatures , int depth ) {
20+
21+ Tree .TreeBuilder builder = Tree .TreeBuilder .newTreeBuilder ();
22+
23+ Tree .Node node = builder .addJunction (0 , randomFeatureIndex (numFeatures ), true , randomDecisionThreshold ());
24+ List <Integer > childNodes = List .of (node .leftChild , node .rightChild );
25+
26+ for (int i =0 ; i <depth -1 ; i ++) {
27+
28+ List <Integer > nextNodes = new ArrayList <>();
29+ for (int nodeId : childNodes ) {
30+
31+ if (i == depth -2 ) {
32+ builder .addLeaf (nodeId , randomDecisionThreshold ());
33+ } else {
34+ Tree .Node childNode = builder .addJunction (nodeId , randomFeatureIndex (numFeatures ), true , randomDecisionThreshold ());
35+ nextNodes .add (childNode .leftChild );
36+ nextNodes .add (childNode .rightChild );
37+ }
38+ }
39+
40+ childNodes = nextNodes ;
41+ }
42+
43+ return builder .build ();
44+ }
45+
46+ static int randomFeatureIndex (int max ) {
47+ return randomIntBetween (0 , max -1 );
48+ }
49+
50+ static double randomDecisionThreshold () {
51+ return randomDouble ();
52+ }
53+
54+ public void testPredict () {
55+ // Build a tree with 2 nodes and 3 leaves using 2 features
56+ // The leaves have unique values 0.1, 0.2, 0.3
57+ Tree .TreeBuilder builder = Tree .TreeBuilder .newTreeBuilder ();
58+ Tree .Node rootNode = builder .addJunction (0 , 0 , true , 0.5 );
59+ builder .addLeaf (rootNode .rightChild , 0.3 );
60+ Tree .Node leftChildNode = builder .addJunction (rootNode .leftChild , 1 , true , 0.8 );
61+ builder .addLeaf (leftChildNode .leftChild , 0.1 );
62+ builder .addLeaf (leftChildNode .rightChild , 0.2 );
63+
64+ Tree tree = builder .build ();
65+
66+ // This feature vector should hit the right child of the root node
67+ List <Double > featureVector = Arrays .asList (0.6 , 0.0 );
68+ assertEquals (0.3 , tree .predict (featureVector ), 0.00001 );
69+
70+ // This should hit the left child of the left child of the root node
71+ // i.e. it takes the path left, left
72+ featureVector = Arrays .asList (0.3 , 0.7 );
73+ assertEquals (0.1 , tree .predict (featureVector ), 0.00001 );
74+
75+ // This should hit the right child of the left child of the root node
76+ // i.e. it takes the path left, right
77+ featureVector = Arrays .asList (0.3 , 0.8 );
78+ assertEquals (0.2 , tree .predict (featureVector ), 0.00001 );
79+ }
80+
81+ public void testTrace () {
82+ int numFeatures = randomIntBetween (1 , 6 );
83+ int depth = 6 ;
84+ Tree tree = buildRandomTree (numFeatures , depth );
85+
86+ List <Double > features = new ArrayList <>(numFeatures );
87+ for (int i =0 ; i <numFeatures ; i ++) {
88+ features .add (randomDecisionThreshold ());
89+ }
90+
91+ List <Tree .Node > trace = tree .trace (features );
92+ assertThat (trace , hasSize (depth ));
93+ for (int i =0 ; i <trace .size () -2 ; i ++) {
94+ assertFalse (trace .get (i ).isLeaf ());
95+ }
96+ assertTrue (trace .get (trace .size () -1 ).isLeaf ());
97+
98+ double prediction = tree .predict (features );
99+ assertEquals (trace .get (trace .size () -1 ).value (), prediction , 0.0001 );
100+
101+ // Because the tree is built up breadth first we can figure out o
102+ // a node's id from its child nodes. Then we can trace the route
103+ // taken an assert it's the branch decisions were correct
104+
105+ int expectedNodeId = 0 ;
106+ for (Tree .Node visitedNode : trace ) {
107+ if (visitedNode .isLeaf () == false ) {
108+ // Imagine the nodes array is 1 based index. The root node
109+ // has index 1, its children 2 & 3. Because the tree is built
110+ // breadth first node 2 children will be at indexes 4 & 5 and
111+ // node 3 children are at 6 & 7.
112+ // So a nodes children are at nodeindex * 2 and nodeindex * 2 +1
113+ // and the parent is at nodeindex / 2.
114+ // The +/- 1's are adjusting for a 0 based index
115+ int nodeId = ((visitedNode .leftChild + 1 ) / 2 ) - 1 ;
116+ assertEquals (expectedNodeId , nodeId );
117+
118+ // unfortunately this doesn't apply to leaf nodes
119+ // as their children are -1
120+
121+ expectedNodeId = visitedNode .compare (features );
122+ } else {
123+ assertEquals (prediction , visitedNode .value (), 0.0001 );
124+ }
125+ }
126+
127+ assertThat (tree .missingNodes (), hasSize (0 ));
128+ }
129+
130+ public void testCompare () {
131+ int leftChild = 1 ;
132+ int rightChild = 2 ;
133+ Tree .Node node = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 );
134+
135+ List <Double > features = List .of (0.1 );
136+ assertEquals (leftChild , node .compare (features ));
137+
138+ features = List .of (0.9 );
139+ assertEquals (rightChild , node .compare (features ));
140+ }
141+
142+ public void testCompare_nonDefaultOperator () {
143+ int leftChild = 1 ;
144+ int rightChild = 2 ;
145+ Tree .Node node = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 , (value , threshold ) -> value >= threshold );
146+
147+ List <Double > features = List .of (0.1 );
148+ assertEquals (rightChild , node .compare (features ));
149+ features = List .of (0.5 );
150+ assertEquals (leftChild , node .compare (features ));
151+ features = List .of (0.9 );
152+ assertEquals (leftChild , node .compare (features ));
153+
154+ node = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 , (value , threshold ) -> value <= threshold );
155+
156+ features = List .of (0.1 );
157+ assertEquals (leftChild , node .compare (features ));
158+ features = List .of (0.5 );
159+ assertEquals (leftChild , node .compare (features ));
160+ features = List .of (0.9 );
161+ assertEquals (rightChild , node .compare (features ));
162+ }
163+
164+ public void testCompare_missingFeature () {
165+ int leftChild = 1 ;
166+ int rightChild = 2 ;
167+ Tree .Node leftBiasNode = new Tree .Node (leftChild , rightChild , 0 , true , 0.5 );
168+ List <Double > features = new ArrayList <>();
169+ features .add (null );
170+ assertEquals (leftChild , leftBiasNode .compare (features ));
171+
172+ Tree .Node rightBiasNode = new Tree .Node (leftChild , rightChild , 0 , false , 0.5 );
173+ assertEquals (rightChild , rightBiasNode .compare (features ));
174+ }
175+
176+ public void testIsLeaf () {
177+ Tree .Node leaf = new Tree .Node (0.0 );
178+ assertTrue (leaf .isLeaf ());
179+
180+ Tree .Node node = new Tree .Node (1 , 2 , 0 , false , 0.0 );
181+ assertFalse (node .isLeaf ());
182+ }
183+
184+ public void testMissingNodes () {
185+ Tree .TreeBuilder builder = Tree .TreeBuilder .newTreeBuilder ();
186+ Tree .Node rootNode = builder .addJunction (0 , 0 , true , randomDecisionThreshold ());
187+
188+ Tree .Node node2 = builder .addJunction (rootNode .rightChild , 0 , false , 0.1 );
189+ builder .addLeaf (node2 .leftChild , 0.1 );
190+
191+ List <Integer > missingNodeIndexes = builder .build ().missingNodes ();
192+ assertEquals (Arrays .asList (1 , 4 ), missingNodeIndexes );
193+ }
194+ }
0 commit comments