@@ -29,12 +29,12 @@ version (2.5+ or later) installed.
2929
3030Additionally, the following Python libraries are required:
3131
32- * `graphdatascience`
33- ( https://neo4j.com/docs/graph-data-science-client/current/installation/[see
34- documentation for installation instructions])
35- * PyG
36- ( https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
37- PyG documentation for installation instructions])
32+ * `graphdatascience`,
33+ https://neo4j.com/docs/graph-data-science-client/current/installation/[see
34+ documentation for installation instructions]
35+ * `pytorch-geometric` version >= 2.5.0,
36+ https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
37+ PyG documentation for installation instructions]
3838
3939== Setup
4040
@@ -233,15 +233,13 @@ format it into a `Data` structure suitable for training with PyG.
233233[source, python, role=no-test]
234234----
235235def create_data_from_graph(relationship_type):
236- rels_tmp = gds.graph.relationshipProperties.stream(
237- ttv_G, ["rel_id"], relationship_type, separate_property_columns=True
238- )
236+ rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
239237 topology = [
240238 rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
241239 rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),
242240 ]
243241 edge_index = torch.tensor(topology, dtype=torch.long)
244- edge_type = torch.tensor(rels_tmp.rel_id .astype(int), dtype=torch.long)
242+ edge_type = torch.tensor(rels_tmp.propertyValue .astype(int), dtype=torch.long)
245243 data = Data(edge_index=edge_index, edge_type=edge_type)
246244 data.num_nodes = len(nodeId_to_id)
247245 display(data)
@@ -303,7 +301,7 @@ def train_model_with_pyg():
303301 head_index=data.edge_index[0],
304302 rel_type=data.edge_type,
305303 tail_index=data.edge_index[1],
306- batch_size=20000 ,
304+ batch_size=1000 ,
307305 k=10,
308306 )
309307
@@ -316,12 +314,11 @@ def train_model_with_pyg():
316314 rank, hits = test(val_tensor_data)
317315 print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}")
318316
319- print(model)
320- rank, hits_at_10 = test(test_tensor_data)
321- print(f"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}")
322-
323317 torch.save(model, f"./model_{epoch_count}.pt")
324318
319+ mean_rank, mrr, hits_at_k = test(test_tensor_data)
320+ print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
321+
325322 return model
326323----
327324
0 commit comments