Skip to content

Commit 9fac912

Browse files
committed
fix feature padding in GCPN
1 parent 342c87e commit 9fac912

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torchdrug/tasks/generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,11 +1269,11 @@ def _apply_action(self, graph, off_policy, max_resample=10, verbose=0, min_node=
12691269
data_dict.pop(key)
12701270
# pad 0 for node / edge attributes
12711271
for k, v in data_dict.items():
1272-
if meta_dict[k] == "node":
1272+
if "node" in meta_dict[k]:
12731273
shape = (len(new_atom_type), *v.shape[1:])
12741274
new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
12751275
data_dict[k] = functional._extend(v, graph.num_nodes, new_data, has_new_node)[0]
1276-
if meta_dict[k] == "edge":
1276+
if "edge" in meta_dict[k]:
12771277
shape = (len(new_edge_list) * 2, *v.shape[1:])
12781278
new_data = torch.zeros(shape, dtype=v.dtype, device=self.device)
12791279
data_dict[k] = functional._extend(v, graph.num_edges, new_data, has_new_edge * 2)[0]

0 commit comments

Comments
 (0)