Skip to content

Commit ce22056

Browse files
committed
Fix correspond adoc file
1 parent b64bdae commit ce22056

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

doc/modules/ROOT/pages/tutorials/kge-predict-transe-pyg-train.adoc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ version (2.5+ or later) installed.
2929

3030
Additionally, the following Python libraries are required:
3131

32-
* `graphdatascience`
33-
(https://neo4j.com/docs/graph-data-science-client/current/installation/[see
34-
documentation for installation instructions])
35-
* PyG
36-
(https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
37-
PyG documentation for installation instructions])
32+
* `graphdatascience`,
33+
https://neo4j.com/docs/graph-data-science-client/current/installation/[see
34+
documentation for installation instructions]
35+
* `pytorch-geometric` version >= 2.5.0,
36+
https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html[see
37+
PyG documentation for installation instructions]
3838

3939
== Setup
4040

@@ -233,15 +233,13 @@ format it into a `Data` structure suitable for training with PyG.
233233
[source, python, role=no-test]
234234
----
235235
def create_data_from_graph(relationship_type):
236-
rels_tmp = gds.graph.relationshipProperties.stream(
237-
ttv_G, ["rel_id"], relationship_type, separate_property_columns=True
238-
)
236+
rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
239237
topology = [
240238
rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
241239
rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),
242240
]
243241
edge_index = torch.tensor(topology, dtype=torch.long)
244-
edge_type = torch.tensor(rels_tmp.rel_id.astype(int), dtype=torch.long)
242+
edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)
245243
data = Data(edge_index=edge_index, edge_type=edge_type)
246244
data.num_nodes = len(nodeId_to_id)
247245
display(data)
@@ -303,7 +301,7 @@ def train_model_with_pyg():
303301
head_index=data.edge_index[0],
304302
rel_type=data.edge_type,
305303
tail_index=data.edge_index[1],
306-
batch_size=20000,
304+
batch_size=1000,
307305
k=10,
308306
)
309307
@@ -316,12 +314,11 @@ def train_model_with_pyg():
316314
rank, hits = test(val_tensor_data)
317315
print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}")
318316
319-
print(model)
320-
rank, hits_at_10 = test(test_tensor_data)
321-
print(f"Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}")
322-
323317
torch.save(model, f"./model_{epoch_count}.pt")
324318
319+
mean_rank, mrr, hits_at_k = test(test_tensor_data)
320+
print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
321+
325322
return model
326323
----
327324

0 commit comments

Comments
 (0)