@@ -31,17 +31,16 @@ def _get_tree_info(X, tree_model, target_names, target_colors, color_map):
3131 :return:
3232 dictionary of useful information
3333 '''
34- # classify features into 3 types: binary, float and int
35- binary_features = []
36- for col in X .columns .values :
37- if list (sorted (np .unique (X [col ].values ))) == [0 , 1 ]:
38- binary_features .append (col )
39-
40- int_features = []
41- for col in list (set (X .columns .values ) - set (binary_features )):
42- if list (X [col ].map (int ).values ) == list (X [col ].values ):
43- int_features .append (col )
44-
34+ binary_features = [
35+ col
36+ for col in X .columns .values
37+ if list (sorted (np .unique (X [col ].values ))) == [0 , 1 ]
38+ ]
39+ int_features = [
40+ col
41+ for col in list (set (X .columns .values ) - set (binary_features ))
42+ if list (X [col ].map (int ).values ) == list (X [col ].values )
43+ ]
4544 # get feature names
4645 feature_names = X .columns .values
4746
@@ -51,23 +50,19 @@ def _get_tree_info(X, tree_model, target_names, target_colors, color_map):
5150
5251 # color mapping for targets
5352 if target_colors is None :
54- if color_map is not None :
55- cm = plt .get_cmap (color_map )
56- else :
57- cm = plt .get_cmap ('tab20' )
58- target_colors = []
59- for n in range (tree_model .tree_ .n_classes [0 ]):
60- target_colors .append (str (matplotlib .colors .rgb2hex (cm (n + 1 ))))
61-
62- tree_info = {
53+ cm = plt .get_cmap ('tab20' ) if color_map is None else plt .get_cmap (color_map )
54+ target_colors = [
55+ str (matplotlib .colors .rgb2hex (cm (n + 1 )))
56+ for n in range (tree_model .tree_ .n_classes [0 ])
57+ ]
58+ return {
6359 'tree_model' : tree_model ,
6460 'features' : [feature_names [i ] for i in tree_model .tree_ .feature ],
6561 'binary_features' : binary_features ,
6662 'int_features' : int_features ,
6763 'target_names' : target_names ,
68- 'target_colors' : target_colors
64+ 'target_colors' : target_colors ,
6965 }
70- return tree_info
7166
7267
7368def _parse_tree (node_id , parent , pos , tree_info ):
@@ -86,30 +81,33 @@ def _parse_tree(node_id, parent, pos, tree_info):
8681 complete tree structure
8782 '''
8883 tree_model = tree_info ['tree_model' ]
89- features = tree_info ['features' ]
90- binary_features = tree_info ['binary_features' ]
91- int_features = tree_info ['int_features' ]
9284 target_names = tree_info ['target_names' ]
9385
9486 node = {}
9587 if parent == 'null' :
9688 node ['name' ] = "HEAD"
9789 else :
90+ features = tree_info ['features' ]
9891 feature = features [parent ]
92+ binary_features = tree_info ['binary_features' ]
93+ int_features = tree_info ['int_features' ]
9994 if pos == 'left' :
10095 if feature in binary_features :
101- node ['name' ] = feature + ' : 0'
96+ node ['name' ] = f' { feature } : 0'
10297 elif feature in int_features :
103- node ['name' ] = feature + " <= " + str ( int (tree_model .tree_ .threshold [parent ]))
98+ node ['name' ] = f" { feature } <= { int (tree_model .tree_ .threshold [parent ])} "
10499 else :
105- node ['name' ] = feature + " <= " + str (round (tree_model .tree_ .threshold [parent ], 3 ))
100+ node [
101+ 'name'
102+ ] = f"{ feature } <= { str (round (tree_model .tree_ .threshold [parent ], 3 ))} "
103+ elif feature in binary_features :
104+ node ['name' ] = f'{ feature } : 1'
105+ elif feature in int_features :
106+ node ['name' ] = f"{ feature } > { int (tree_model .tree_ .threshold [parent ])} "
106107 else :
107- if feature in binary_features :
108- node ['name' ] = feature + ': 1'
109- elif feature in int_features :
110- node ['name' ] = feature + " > " + str (int (tree_model .tree_ .threshold [parent ]))
111- else :
112- node ['name' ] = feature + " > " + str (round (tree_model .tree_ .threshold [parent ], 3 ))
108+ node [
109+ 'name'
110+ ] = f"{ feature } > { str (round (tree_model .tree_ .threshold [parent ], 3 ))} "
113111 try :
114112 node ['parent' ] = int (parent )
115113 except :
@@ -125,12 +123,12 @@ def _parse_tree(node_id, parent, pos, tree_info):
125123
126124 if tree_model .tree_ .children_left [node_id ] != - 1 or tree_model .tree_ .children_right [node_id ] != - 1 :
127125 node ['children' ] = []
128- if tree_model .tree_ .children_left [node_id ] != - 1 :
129- child = tree_model .tree_ .children_left [node_id ]
130- node ['children' ].append (_parse_tree (child , node_id , 'left' , tree_info ))
131- if tree_model .tree_ .children_right [node_id ] != - 1 :
132- child = tree_model .tree_ .children_right [node_id ]
133- node ['children' ].append (_parse_tree (child , node_id , 'right' , tree_info ))
126+ if tree_model .tree_ .children_left [node_id ] != - 1 :
127+ child = tree_model .tree_ .children_left [node_id ]
128+ node ['children' ].append (_parse_tree (child , node_id , 'left' , tree_info ))
129+ if tree_model .tree_ .children_right [node_id ] != - 1 :
130+ child = tree_model .tree_ .children_right [node_id ]
131+ node ['children' ].append (_parse_tree (child , node_id , 'right' , tree_info ))
134132 return node
135133
136134
@@ -154,9 +152,7 @@ def _extract_rules(node_id, parent, pos, tree_rules, tree_info):
154152 features = tree_info ['features' ]
155153 tree_model = tree_info ['tree_model' ]
156154
157- tree_rules [node_id ] = {}
158- tree_rules [node_id ]['features' ] = {}
159-
155+ tree_rules [node_id ] = {'features' : {}}
160156 if parent != "null" :
161157 previous = copy .deepcopy (tree_rules [parent ]['features' ])
162158 tree_rules [node_id ]['features' ] = previous
@@ -202,24 +198,20 @@ def _clean_rules(tree_rules, tree_info):
202198 for k in node ['features' ].keys ():
203199 feat = node ['features' ][k ]
204200 if k in tree_info ['binary_features' ]:
205- if feat [0 ] == - sys .maxsize :
206- rule = k + ': 0'
207- else :
208- rule = k + ': 1'
201+ rule = f'{ k } : 0' if feat [0 ] == - sys .maxsize else f'{ k } : 1'
209202 elif k in tree_info ['int_features' ]:
210203 if feat [0 ] == - sys .maxsize :
211- rule = k + ' <= ' + str ( int (feat [1 ]))
204+ rule = f' { k } <= { int (feat [1 ])} '
212205 elif feat [1 ] == sys .maxsize :
213- rule = k + ' > ' + str ( int (feat [0 ]))
206+ rule = f' { k } > { int (feat [0 ])} '
214207 else :
215- rule = str (int (feat [0 ])) + ' < ' + k + ' <= ' + str (int (feat [1 ]))
208+ rule = f'{ int (feat [0 ])} < { k } <= { int (feat [1 ])} '
209+ elif feat [0 ] == - sys .maxsize :
210+ rule = f'{ k } <= { str (round (feat [1 ], 3 ))} '
211+ elif feat [1 ] == sys .maxsize :
212+ rule = f'{ k } > { str (round (feat [0 ], 3 ))} '
216213 else :
217- if feat [0 ] == - sys .maxsize :
218- rule = k + ' <= ' + str (round (feat [1 ], 3 ))
219- elif feat [1 ] == sys .maxsize :
220- rule = k + ' > ' + str (round (feat [0 ], 3 ))
221- else :
222- rule = str (round (feat [0 ], 3 )) + ' < ' + k + ' <= ' + str (round (feat [1 ], 3 ))
214+ rule = f'{ str (round (feat [0 ], 3 ))} < { k } <= { str (round (feat [1 ], 3 ))} '
223215 rules .append (rule )
224216 rules = sorted (rules , key = lambda x : len (x ))
225217 tree_rules_clean [key ] = rules
0 commit comments