Skip to content

Commit 342c87e

Browse files
committed
fix feature padding in G2Gs
1 parent 5e07dd3 commit 342c87e

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchdrug/tasks/retrosynthesis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def predict_synthon(self, batch, k=1):
203203
center_topk_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device),
204204
center_topk[:-1]])
205205
product_id_shifted = torch.cat([-torch.ones(1, dtype=torch.long, device=self.device),
206-
graph.product_id[:-1]])
206+
graph.product_id[:-1]])
207207
is_duplicate = (center_topk == center_topk_shifted) & (graph.product_id == product_id_shifted)
208208
node_index = node_index[~is_edge]
209209
edge_index = edge_index[is_edge]
@@ -847,11 +847,11 @@ def _apply_action(self, graph, action, logp):
847847
data_dict.pop(key)
848848
# pad 0 for node / edge attributes
849849
for k, v in data_dict.items():
850-
if meta_dict[k] == "node":
850+
if "node" in meta_dict[k]:
851851
shape = (len(new_atom_type), *v.shape[1:])
852852
new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
853853
data_dict[k] = functional._extend(v, graph.num_nodes, new_data, has_new_node)[0]
854-
if meta_dict[k] == "edge":
854+
if "edge" in meta_dict[k]:
855855
shape = (len(new_edge_list) * 2, *v.shape[1:])
856856
new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
857857
data_dict[k] = functional._extend(v, graph.num_edges, new_data, has_new_edge * 2)[0]

0 commit comments

Comments
 (0)