@@ -203,7 +203,7 @@ def predict_synthon(self, batch, k=1):
203203 center_topk_shifted = torch .cat ([- torch .ones (1 , dtype = torch .long , device = self .device ),
204204 center_topk [:- 1 ]])
205205 product_id_shifted = torch .cat ([- torch .ones (1 , dtype = torch .long , device = self .device ),
206- graph .product_id [:- 1 ]])
206+ graph .product_id [:- 1 ]])
207207 is_duplicate = (center_topk == center_topk_shifted ) & (graph .product_id == product_id_shifted )
208208 node_index = node_index [~ is_edge ]
209209 edge_index = edge_index [is_edge ]
@@ -847,11 +847,11 @@ def _apply_action(self, graph, action, logp):
847847 data_dict .pop (key )
848848 # pad 0 for node / edge attributes
849849 for k , v in data_dict .items ():
850- if meta_dict [k ] == "node" :
850+ if "node" in meta_dict [k ]:
851851 shape = (len (new_atom_type ), * v .shape [1 :])
852852 new_data = torch .zeros (shape , dtype = v .dtype , device = self .device )
853853 data_dict [k ] = functional ._extend (v , graph .num_nodes , new_data , has_new_node )[0 ]
854- if meta_dict [k ] == "edge" :
854+ if "edge" in meta_dict [k ]:
855855 shape = (len (new_edge_list ) * 2 , * v .shape [1 :])
856856 new_data = torch .zeros (shape , dtype = v .dtype , device = self .device )
857857 data_dict [k ] = functional ._extend (v , graph .num_edges , new_data , has_new_edge * 2 )[0 ]
0 commit comments