Commit d340ff0

Added doc about triplet scoring
1 parent d9b34bb commit d340ff0

2 files changed

+99
-7
lines changed

doc/modules/ROOT/pages/gds-session-algorithms/knowledge-graph-embeddings.adoc

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ predict_result = gds.kge.model.predict(
326326

327327
|====
328328

329-
For every `N` head entities and `M` relationship types, the function returns `N*M` rows.
329+
For every `N` head entities and `M` relationship types, the function returns `N*M` rows.
330330
The result object is a pandas DataFrame with the following columns:
331331

332332
.Predict result
@@ -352,6 +352,89 @@ The result object is pandas DataFrame with the following columns:
352352

353353
|====
354354

355+
=== Triplets scoring
356+
357+
The `score_triplets` function computes a score for each of the given triplets.
358+
Triplets are passed as a list of tuples `(head, relation, tail)`, where `head` and `tail` are node IDs, which can be obtained with the `gds.find_node_id` function.
359+
`relation` is a string representing the relationship type.
360+
[source, python, role=no-test]
361+
----
362+
scores = gds.kge.model.score_triplets(
363+
    model_name=model_name,
364+
    triplets=[(node_id1, "RELATIONSHIP_TYPE1", node_id2), ... ],
365+
)
366+
----
367+
.Parameters
368+
[cols="1m,1m,1m,1", options="header"]
369+
|====
370+
| Parameter | Type | Default value | Description
371+
372+
| model_name
373+
| str
374+
| N/A
375+
| The name of the model to use for scoring
376+
377+
| triplets
378+
| list[tuple[int, str, int]]
379+
| N/A
380+
| List of triplets to score
381+
382+
|====
383+
384+
The `score_triplets` function returns a list of scores, where each score belongs to the triplet at the same index in the input list.
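Because the returned scores are aligned with the input by index, they can be paired back with their triplets using `zip`. The following is a minimal sketch: the node IDs and score values are hypothetical placeholders standing in for the output of `score_triplets`, and the assumption that a higher score means a more plausible triplet should be checked against the scoring function in use.

```python
# Hypothetical inputs: node IDs and relationship types are placeholders.
triplets = [
    (0, "REL_A", 1),
    (0, "REL_B", 2),
    (3, "REL_A", 1),
]
# Placeholder values standing in for the list returned by score_triplets.
scores = [-1.25, -0.40, -2.10]

# Scores align with triplets by index, so zip pairs them back up.
scored = list(zip(triplets, scores))

# Sort so the highest score comes first (assumes higher means more plausible).
ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)

for (head, relation, tail), score in ranked:
    print(f"({head}, {relation}, {tail}) -> {score:.2f}")
```

Keeping the triplet and its score together in one tuple avoids index bookkeeping when the list is later sorted or filtered.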
385+
355386
[[algorithms-embeddings-kge-examples]]
356387
== Examples
357-
TODO
388+
389+
Let's train the `TransE` model on the `Nations` dataset and predict the tail entities for a given head entity and relationship type.
390+
Load the `Nations` dataset into a Neo4j database and project the graph `G_train` for training.
391+
392+
Check that the projected graph has multiple relationship types by calling `G_train.relationship_types()`.
393+
394+
[source, python, role=no-test]
395+
----
396+
model_name = "my_transe_model"
397+
398+
gds.kge.model.train(
399+
    G_train,
400+
    model_name=model_name,
401+
    scoring_function="transe",
402+
    num_epochs=30,
403+
    embedding_dimension=64,
404+
    split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1},
405+
)
406+
----
407+
408+
This trains the `TransE` model, which can then be used for prediction.
409+
The top-k tail entities for a given head entity and relationship type can be predicted as follows:
410+
411+
[source, python, role=no-test]
412+
----
413+
brazil_node = gds.find_node_id(["Entity"], {"text": "brazil"})
414+
uk_node = gds.find_node_id(["Entity"], {"text": "uk"})
415+
jordan_node = gds.find_node_id(["Entity"], {"text": "jordan"})
416+
417+
predict_result = gds.kge.model.predict(
418+
    model_name=model_name,
419+
    top_k=3,
420+
    node_ids=[brazil_node, uk_node, jordan_node],
421+
    rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"],
422+
)
423+
424+
print(predict_result.to_string())
425+
----
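The call above passes three head nodes and two relationship types, so by the `N*M` rule stated for `predict` it covers 3 * 2 = 6 (head, relationship type) combinations. The sketch below simply enumerates those combinations. The integer node IDs are hypothetical stand-ins for the values returned by `gds.find_node_id`, and nothing here calls the library.

```python
from itertools import product

# Hypothetical node IDs standing in for brazil_node, uk_node and jordan_node.
node_ids = [10, 20, 30]
rel_types = ["REL_RELDIPLOMACY", "REL_RELNGO"]

# Every head entity is combined with every relationship type.
combinations = list(product(node_ids, rel_types))

print(len(combinations))
for head, rel_type in combinations:
    print(head, rel_type)
```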
426+
427+
Specific triplets can be scored with the `score_triplets` function:
428+
429+
[source, python, role=no-test]
430+
----
431+
triplets = [
432+
    (brazil_node, "REL_RELNGO", uk_node),
433+
    (brazil_node, "REL_RELDIPLOMACY", jordan_node),
434+
]
435+
436+
scores = gds.kge.model.score_triplets(
437+
    model_name=model_name,
438+
    triplets=triplets,
439+
)
440+
----
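For intuition about what a TransE score measures: TransE models a relationship as a translation in embedding space, so a triplet is plausible when `head + relation` lands close to `tail`. The sketch below computes one common formulation, the negative Euclidean distance, on made-up vectors. It is not the library's implementation: the embedding values are invented, `transe_score` is a hypothetical helper, and the norm and sign convention used by `score_triplets` are assumptions to verify against its actual output.

```python
import math

def transe_score(head, relation, tail):
    """Negative Euclidean distance between (head + relation) and tail."""
    translated = [h + r for h, r in zip(head, relation)]
    return -math.dist(translated, tail)

# Made-up 3-dimensional embeddings, for illustration only.
head = [1.0, 0.0, 2.0]
relation = [0.5, 1.0, -1.0]
near_tail = [1.5, 1.0, 1.0]   # exactly head + relation
far_tail = [-2.0, 3.0, 4.0]   # far from head + relation

near_score = transe_score(head, relation, near_tail)
far_score = transe_score(head, relation, far_tail)

# Under this convention the closer tail receives the higher score.
print(near_score > far_score)
```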

examples/kge-distmult-nations.ipynb

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,16 @@
228228
"G_train, G_valid, G_test = project_graphs()"
229229
]
230230
},
231+
{
232+
"cell_type": "code",
233+
"execution_count": null,
234+
"id": "21da1ea76d247803",
235+
"metadata": {},
236+
"outputs": [],
237+
"source": [
238+
"G_train.relationship_types()"
239+
]
240+
},
231241
{
232242
"cell_type": "code",
233243
"execution_count": null,
@@ -242,11 +252,10 @@
242252
"gds.kge.model.train(\n",
243253
" G_train,\n",
244254
" model_name=model_name,\n",
245-
" scoring_function=\"distmult\",\n",
246-
" num_epochs=1,\n",
247-
" embedding_dimension=10,\n",
248-
" epochs_per_checkpoint=0,\n",
249-
" epochs_per_val=0,\n",
255+
" scoring_function=\"transe\",\n",
256+
" num_epochs=30,\n",
257+
" embedding_dimension=64,\n",
258+
" split_ratios={\"TRAIN\": 0.8, \"VALID\": 0.1, \"TEST\": 0.1},\n",
250259
")\n",
251260
"\n",
252261
"predict_result = gds.kge.model.predict(\n",
