|
37 | 37 | "\n", |
38 | 38 | "Additionally, the following Python libraries are required:\n", |
39 | 39 | "\n", |
40 | | - "- `graphdatascience` ([see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/))\n", |
41 | | - "- PyG ([see PyG documentation for installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html))\n", |
| 40 | + "- `graphdatascience`, [see documentation for installation instructions](https://neo4j.com/docs/graph-data-science-client/current/installation/)\n", |
| 41 | + "- `pytorch-geometric` version >= 2.5.0, [see PyG documentation for installation instructions](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)\n", |
42 | 42 | "\n", |
43 | 43 | "## Setup\n", |
44 | 44 | "\n", |
|
307 | 307 | "outputs": [], |
308 | 308 | "source": [ |
309 | 309 | "def create_data_from_graph(relationship_type):\n", |
310 | | - " rels_tmp = gds.graph.relationshipProperties.stream(\n", |
311 | | - " ttv_G, [\"rel_id\"], relationship_type, separate_property_columns=True\n", |
312 | | - " )\n", |
| 310 | + " rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, \"rel_id\", relationship_type)\n", |
313 | 311 | " topology = [\n", |
314 | 312 | " rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),\n", |
315 | 313 | " rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),\n", |
316 | 314 | " ]\n", |
317 | 315 | " edge_index = torch.tensor(topology, dtype=torch.long)\n", |
318 | | - " edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)\n", |
| 316 | + " edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)\n", |
319 | 317 | " data = Data(edge_index=edge_index, edge_type=edge_type)\n", |
320 | 318 | " data.num_nodes = len(nodeId_to_id)\n", |
321 | 319 | " display(data)\n", |
|
398 | 396 | " head_index=data.edge_index[0],\n", |
399 | 397 | " rel_type=data.edge_type,\n", |
400 | 398 | " tail_index=data.edge_index[1],\n", |
401 | | - " batch_size=20000,\n", |
| 399 | + " batch_size=1000,\n", |
402 | 400 | " k=10,\n", |
403 | 401 | " )\n", |
404 | 402 | "\n", |
|
411 | 409 | " rank, hits = test(val_tensor_data)\n", |
412 | 410 | " print(f\"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, \" f\"Val Hits@10: {hits:.4f}\")\n", |
413 | 411 | "\n", |
414 | | - " print(model)\n", |
415 | | - " rank, hits_at_10 = test(test_tensor_data)\n", |
416 | | - " print(f\"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}\")\n", |
417 | | - "\n", |
418 | 412 | " torch.save(model, f\"./model_{epoch_count}.pt\")\n", |
419 | 413 | "\n", |
| 414 | + " mean_rank, mrr, hits_at_k = test(test_tensor_data)\n", |
| 415 | + " print(f\"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}\")\n", |
| 416 | + "\n", |
420 | 417 | " return model" |
421 | 418 | ] |
422 | 419 | }, |
|
0 commit comments