@@ -64,7 +64,10 @@ def __init__(self, i, atts):
6464 self .nodes_missing_value_tracks_true = None
6565 for k , v in atts .items ():
6666 if k .startswith ("nodes" ):
67- setattr (self , k , v [i ])
67+ if k .endswith ("_as_tensor" ):
68+ setattr (self , k .replace ("_as_tensor" , "" ), v [i ])
69+ else :
70+ setattr (self , k , v [i ])
6871 self .depth = 0
6972 self .true_false = ""
7073 self .targets = []
@@ -120,10 +123,7 @@ def process_tree(atts, treeid):
120123 ]
121124 for k , v in atts .items ():
122125 if k .startswith (prefix ):
123- if "classlabels" in k :
124- short [k ] = list (v )
125- else :
126- short [k ] = [v [i ] for i in idx ]
126+ short [k ] = list (v ) if "classlabels" in k else [v [i ] for i in idx ]
127127
128128 nodes = OrderedDict ()
129129 for i in range (len (short ["nodes_treeids" ])):
@@ -132,9 +132,10 @@ def process_tree(atts, treeid):
132132 for i in range (len (short [f"{ prefix } _treeids" ])):
133133 idn = short [f"{ prefix } _nodeids" ][i ]
134134 node = nodes [idn ]
135- node .append_target (
136- tid = short [f"{ prefix } _ids" ][i ], weight = short [f"{ prefix } _weights" ][i ]
137- )
135+ key = f"{ prefix } _weights"
136+ if key not in short :
137+ key = f"{ prefix } _weights_as_tensor"
138+ node .append_target (tid = short [f"{ prefix } _ids" ][i ], weight = short [key ][i ])
138139
139140 def iterate (nodes , node , depth = 0 , true_false = "" ):
140141 node .depth = depth
0 commit comments