Skip to content

Commit 745afab

Browse files
author
Sourcery AI
committed
'Refactored by Sourcery'
1 parent a3e111c commit 745afab

File tree

3 files changed

+50
-62
lines changed

3 files changed

+50
-62
lines changed

Plotting/interactive_decision_tree.py

Lines changed: 48 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,16 @@ def _get_tree_info(X, tree_model, target_names, target_colors, color_map):
3131
:return:
3232
dictionary of useful information
3333
'''
34-
# classify features into 3 types: binary, float and int
35-
binary_features = []
36-
for col in X.columns.values:
37-
if list(sorted(np.unique(X[col].values))) == [0, 1]:
38-
binary_features.append(col)
39-
40-
int_features = []
41-
for col in list(set(X.columns.values) - set(binary_features)):
42-
if list(X[col].map(int).values) == list(X[col].values):
43-
int_features.append(col)
44-
34+
binary_features = [
35+
col
36+
for col in X.columns.values
37+
if list(sorted(np.unique(X[col].values))) == [0, 1]
38+
]
39+
int_features = [
40+
col
41+
for col in list(set(X.columns.values) - set(binary_features))
42+
if list(X[col].map(int).values) == list(X[col].values)
43+
]
4544
# get feature names
4645
feature_names = X.columns.values
4746

@@ -51,23 +50,19 @@ def _get_tree_info(X, tree_model, target_names, target_colors, color_map):
5150

5251
# color mapping for targets
5352
if target_colors is None:
54-
if color_map is not None:
55-
cm = plt.get_cmap(color_map)
56-
else:
57-
cm = plt.get_cmap('tab20')
58-
target_colors = []
59-
for n in range(tree_model.tree_.n_classes[0]):
60-
target_colors.append(str(matplotlib.colors.rgb2hex(cm(n + 1))))
61-
62-
tree_info = {
53+
cm = plt.get_cmap('tab20') if color_map is None else plt.get_cmap(color_map)
54+
target_colors = [
55+
str(matplotlib.colors.rgb2hex(cm(n + 1)))
56+
for n in range(tree_model.tree_.n_classes[0])
57+
]
58+
return {
6359
'tree_model': tree_model,
6460
'features': [feature_names[i] for i in tree_model.tree_.feature],
6561
'binary_features': binary_features,
6662
'int_features': int_features,
6763
'target_names': target_names,
68-
'target_colors': target_colors
64+
'target_colors': target_colors,
6965
}
70-
return tree_info
7166

7267

7368
def _parse_tree(node_id, parent, pos, tree_info):
@@ -86,30 +81,33 @@ def _parse_tree(node_id, parent, pos, tree_info):
8681
complete tree structure
8782
'''
8883
tree_model = tree_info['tree_model']
89-
features = tree_info['features']
90-
binary_features = tree_info['binary_features']
91-
int_features = tree_info['int_features']
9284
target_names = tree_info['target_names']
9385

9486
node = {}
9587
if parent == 'null':
9688
node['name'] = "HEAD"
9789
else:
90+
features = tree_info['features']
9891
feature = features[parent]
92+
binary_features = tree_info['binary_features']
93+
int_features = tree_info['int_features']
9994
if pos == 'left':
10095
if feature in binary_features:
101-
node['name'] = feature + ': 0'
96+
node['name'] = f'{feature}: 0'
10297
elif feature in int_features:
103-
node['name'] = feature + " <= " + str(int(tree_model.tree_.threshold[parent]))
98+
node['name'] = f"{feature} <= {int(tree_model.tree_.threshold[parent])}"
10499
else:
105-
node['name'] = feature + " <= " + str(round(tree_model.tree_.threshold[parent], 3))
100+
node[
101+
'name'
102+
] = f"{feature} <= {str(round(tree_model.tree_.threshold[parent], 3))}"
103+
elif feature in binary_features:
104+
node['name'] = f'{feature}: 1'
105+
elif feature in int_features:
106+
node['name'] = f"{feature} > {int(tree_model.tree_.threshold[parent])}"
106107
else:
107-
if feature in binary_features:
108-
node['name'] = feature + ': 1'
109-
elif feature in int_features:
110-
node['name'] = feature + " > " + str(int(tree_model.tree_.threshold[parent]))
111-
else:
112-
node['name'] = feature + " > " + str(round(tree_model.tree_.threshold[parent], 3))
108+
node[
109+
'name'
110+
] = f"{feature} > {str(round(tree_model.tree_.threshold[parent], 3))}"
113111
try:
114112
node['parent'] = int(parent)
115113
except:
@@ -125,12 +123,12 @@ def _parse_tree(node_id, parent, pos, tree_info):
125123

126124
if tree_model.tree_.children_left[node_id] != -1 or tree_model.tree_.children_right[node_id] != -1:
127125
node['children'] = []
128-
if tree_model.tree_.children_left[node_id] != -1:
129-
child = tree_model.tree_.children_left[node_id]
130-
node['children'].append(_parse_tree(child, node_id, 'left', tree_info))
131-
if tree_model.tree_.children_right[node_id] != -1:
132-
child = tree_model.tree_.children_right[node_id]
133-
node['children'].append(_parse_tree(child, node_id, 'right', tree_info))
126+
if tree_model.tree_.children_left[node_id] != -1:
127+
child = tree_model.tree_.children_left[node_id]
128+
node['children'].append(_parse_tree(child, node_id, 'left', tree_info))
129+
if tree_model.tree_.children_right[node_id] != -1:
130+
child = tree_model.tree_.children_right[node_id]
131+
node['children'].append(_parse_tree(child, node_id, 'right', tree_info))
134132
return node
135133

136134

@@ -154,9 +152,7 @@ def _extract_rules(node_id, parent, pos, tree_rules, tree_info):
154152
features = tree_info['features']
155153
tree_model = tree_info['tree_model']
156154

157-
tree_rules[node_id] = {}
158-
tree_rules[node_id]['features'] = {}
159-
155+
tree_rules[node_id] = {'features': {}}
160156
if parent != "null":
161157
previous = copy.deepcopy(tree_rules[parent]['features'])
162158
tree_rules[node_id]['features'] = previous
@@ -202,24 +198,20 @@ def _clean_rules(tree_rules, tree_info):
202198
for k in node['features'].keys():
203199
feat = node['features'][k]
204200
if k in tree_info['binary_features']:
205-
if feat[0] == -sys.maxsize:
206-
rule = k + ': 0'
207-
else:
208-
rule = k + ': 1'
201+
rule = f'{k}: 0' if feat[0] == -sys.maxsize else f'{k}: 1'
209202
elif k in tree_info['int_features']:
210203
if feat[0] == -sys.maxsize:
211-
rule = k + ' <= ' + str(int(feat[1]))
204+
rule = f'{k} <= {int(feat[1])}'
212205
elif feat[1] == sys.maxsize:
213-
rule = k + ' > ' + str(int(feat[0]))
206+
rule = f'{k} > {int(feat[0])}'
214207
else:
215-
rule = str(int(feat[0])) + ' < ' + k + ' <= ' + str(int(feat[1]))
208+
rule = f'{int(feat[0])} < {k} <= {int(feat[1])}'
209+
elif feat[0] == -sys.maxsize:
210+
rule = f'{k} <= {str(round(feat[1], 3))}'
211+
elif feat[1] == sys.maxsize:
212+
rule = f'{k} > {str(round(feat[0], 3))}'
216213
else:
217-
if feat[0] == -sys.maxsize:
218-
rule = k + ' <= ' + str(round(feat[1], 3))
219-
elif feat[1] == sys.maxsize:
220-
rule = k + ' > ' + str(round(feat[0], 3))
221-
else:
222-
rule = str(round(feat[0], 3)) + ' < ' + k + ' <= ' + str(round(feat[1], 3))
214+
rule = f'{str(round(feat[0], 3))} < {k} <= {str(round(feat[1], 3))}'
223215
rules.append(rule)
224216
rules = sorted(rules, key= lambda x : len(x))
225217
tree_rules_clean[key] = rules

Python/codon_expt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import time
22

33
def fib(n):
4-
if n<=1:
5-
return 1
6-
return fib(n-1) + fib(n-2)
4+
return 1 if n<=1 else fib(n-1) + fib(n-2)
75

86
def approximate_pi(num_terms):
97
"""

Run-time Optimization/codon_expt.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import time
22

33
def fib(n):
4-
if n<=1:
5-
return 1
6-
return fib(n-1) + fib(n-2)
4+
return 1 if n<=1 else fib(n-1) + fib(n-2)
75

86
def approximate_pi(num_terms):
97
"""

0 commit comments

Comments
 (0)