Skip to content

Commit d9b34bb

Browse files
committed
Added score_triplets function
1 parent 59b1e43 commit d9b34bb

File tree

4 files changed

+115
-40
lines changed

4 files changed

+115
-40
lines changed

examples/kge-distmult-nations.ipynb

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
{
1414
"cell_type": "code",
1515
"execution_count": null,
16-
"id": "8d9719b198c3fe8e",
16+
"id": "9135277efcde2800",
1717
"metadata": {},
1818
"outputs": [],
1919
"source": [
@@ -29,7 +29,7 @@
2929
{
3030
"cell_type": "code",
3131
"execution_count": null,
32-
"id": "d4d82474217c5ca2",
32+
"id": "1551fddc3a67fa5b",
3333
"metadata": {},
3434
"outputs": [],
3535
"source": [
@@ -39,7 +39,7 @@
3939
{
4040
"cell_type": "code",
4141
"execution_count": null,
42-
"id": "c522b3dba2a0c1c9",
42+
"id": "2f05ee7fdb496f84",
4343
"metadata": {},
4444
"outputs": [],
4545
"source": [
@@ -57,7 +57,7 @@
5757
{
5858
"cell_type": "code",
5959
"execution_count": null,
60-
"id": "532f7596",
60+
"id": "658c9f8369fff77e",
6161
"metadata": {},
6262
"outputs": [],
6363
"source": [
@@ -70,7 +70,7 @@
7070
{
7171
"cell_type": "code",
7272
"execution_count": null,
73-
"id": "00757ac4",
73+
"id": "bdbf4f91da4b9934",
7474
"metadata": {},
7575
"outputs": [],
7676
"source": [
@@ -84,7 +84,7 @@
8484
{
8585
"cell_type": "code",
8686
"execution_count": null,
87-
"id": "6c9a1c4d",
87+
"id": "485869468ad5ad2e",
8888
"metadata": {},
8989
"outputs": [],
9090
"source": [
@@ -142,16 +142,16 @@
142142
" f\"Number of relationships of type {rel_split}: \",\n",
143143
" sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),\n",
144144
" )\n",
145-
" return dataset\n",
145+
" return dataset, node_map\n",
146146
"\n",
147147
"\n",
148-
"dataset = read_data()"
148+
"dataset, node_map = read_data()"
149149
]
150150
},
151151
{
152152
"cell_type": "code",
153153
"execution_count": null,
154-
"id": "e1cb98e4",
154+
"id": "2032a4e1aed1bd5",
155155
"metadata": {},
156156
"outputs": [],
157157
"source": [
@@ -160,7 +160,6 @@
160160
" if res[\"num_nodes\"].values[0] > 0:\n",
161161
" print(\"Data already in db, number of nodes: \", res[\"num_nodes\"].values[0])\n",
162162
" return\n",
163-
" dataset = read_data()\n",
164163
" pbar = tqdm(\n",
165164
" desc=\"Putting data in db\",\n",
166165
" total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),\n",
@@ -198,7 +197,7 @@
198197
{
199198
"cell_type": "code",
200199
"execution_count": null,
201-
"id": "0fceb15b",
200+
"id": "5c4f1523a225fa3c",
202201
"metadata": {},
203202
"outputs": [],
204203
"source": [
@@ -232,7 +231,7 @@
232231
{
233232
"cell_type": "code",
234233
"execution_count": null,
235-
"id": "b4e2825a",
234+
"id": "5d518e67375f6ab3",
236235
"metadata": {},
237236
"outputs": [],
238237
"source": [
@@ -261,43 +260,53 @@
261260
" rel_types=[\"REL_RELDIPLOMACY\", \"REL_RELNGO\"],\n",
262261
")\n",
263262
"\n",
264-
"print(predict_result.to_string())\n",
265-
"#\n",
266-
"# gds.kge.model.predict_tail(\n",
267-
"# G_train,\n",
268-
"# model_name=model_name,\n",
269-
"# top_k=10,\n",
270-
"# node_ids=[gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), gds.find_node_id([\"Entity\"], {\"id\": 2})],\n",
271-
"# rel_types=[\"REL_1\", \"REL_2\"],\n",
272-
"# )\n",
273-
"#\n",
274-
"# gds.kge.model.score_triples(\n",
275-
"# G_train,\n",
276-
"# model_name=model_name,\n",
277-
"# triples=[\n",
278-
"# (gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), \"REL_1\", gds.find_node_id([\"Entity\"], {\"id\": 2})),\n",
279-
"# (gds.find_node_id([\"Entity\"], {\"id\": 0}), \"REL_123\", gds.find_node_id([\"Entity\"], {\"id\": 3})),\n",
280-
"# ],\n",
281-
"# )"
263+
"print(predict_result.to_string())"
282264
]
283265
},
284266
{
285267
"cell_type": "code",
286268
"execution_count": null,
287-
"id": "786eda29280ed31f",
269+
"id": "83b75194c69259a2",
288270
"metadata": {},
289271
"outputs": [],
290272
"source": [
291-
"# Create the dictionary"
273+
"for index, row in predict_result.iterrows():\n",
274+
" h = row[\"head\"]\n",
275+
" r = row[\"rel\"]\n",
276+
" gds.run_cypher(\n",
277+
" f\"\"\"\n",
278+
" UNWIND $tt as t\n",
279+
" MATCH (a:Entity WHERE id(a) = {h})\n",
280+
" MATCH (b:Entity WHERE id(b) = t)\n",
281+
" MERGE (a)-[:NEW_REL_{r}]->(b)\n",
282+
" \"\"\",\n",
283+
" params={\"tt\": row[\"tail\"]},\n",
284+
" )"
292285
]
293286
},
294287
{
295288
"cell_type": "code",
296289
"execution_count": null,
297-
"id": "74c501f8fcb411eb",
290+
"id": "b4e2825a",
298291
"metadata": {},
299292
"outputs": [],
300-
"source": []
293+
"source": [
294+
"brazil_node = gds.find_node_id([\"Entity\"], {\"text\": \"brazil\"})\n",
295+
"uk_node = gds.find_node_id([\"Entity\"], {\"text\": \"uk\"})\n",
296+
"jordan_node = gds.find_node_id([\"Entity\"], {\"text\": \"jordan\"})\n",
297+
"\n",
298+
"triplets = [\n",
299+
" (brazil_node, \"REL_RELNGO\", uk_node),\n",
300+
" (brazil_node, \"REL_RELDIPLOMACY\", jordan_node),\n",
301+
"]\n",
302+
"\n",
303+
"scores = gds.kge.model.score_triplets(\n",
304+
" model_name=model_name,\n",
305+
" triplets=triplets,\n",
306+
")\n",
307+
"\n",
308+
"print(scores)"
309+
]
301310
}
302311
],
303312
"metadata": {},

examples/kge-distmult-nations.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,8 @@ def inspect_graph(G):
186186
put_data_in_db(gds)
187187
G_train, G_valid, G_test = project_graphs(gds)
188188

189+
inspect_graph(G_train)
190+
189191
gds.set_compute_cluster_ip("localhost")
190192

191193
model_name = "dummyModelName_" + str(time.time())
@@ -197,10 +199,11 @@ def inspect_graph(G):
197199
num_epochs=1,
198200
embedding_dimension=10,
199201
epochs_per_checkpoint=0,
200-
epochs_per_val=0,
202+
epochs_per_val=5,
203+
split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1},
201204
)
202205

203-
df = gds.kge.model.predict(
206+
predict_result = gds.kge.model.predict(
204207
model_name=model_name,
205208
top_k=3,
206209
node_ids=[
@@ -211,7 +214,37 @@ def inspect_graph(G):
211214
rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"],
212215
)
213216

214-
print(df.to_string())
217+
print(predict_result.to_string())
218+
219+
print(predict_result.to_string())
220+
for index, row in predict_result.iterrows():
221+
h = row["head"]
222+
r = row["rel"]
223+
gds.run_cypher(
224+
f"""
225+
UNWIND $tt as t
226+
MATCH (a:Entity WHERE id(a) = {h})
227+
MATCH (b:Entity WHERE id(b) = t)
228+
MERGE (a)-[:NEW_REL_{r}]->(b)
229+
""",
230+
params={"tt": row["tail"]},
231+
)
232+
233+
brazil_node = gds.find_node_id(["Entity"], {"text": "brazil"})
234+
uk_node = gds.find_node_id(["Entity"], {"text": "uk"})
235+
jordan_node = gds.find_node_id(["Entity"], {"text": "jordan"})
236+
237+
triplets = [
238+
(brazil_node, "REL_RELNGO", uk_node),
239+
(brazil_node, "REL_RELDIPLOMACY", jordan_node),
240+
]
241+
242+
scores = gds.kge.model.score_triplets(
243+
model_name=model_name,
244+
triplets=triplets,
245+
)
246+
247+
print(scores)
215248
#
216249
# gds.kge.model.predict_tail(
217250
# G_train,

graphdatascience/graph_data_science.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from .query_runner.query_runner import QueryRunner
1818
from .server_version.server_version import ServerVersion
1919
from graphdatascience.graph.graph_proc_runner import GraphProcRunner
20-
from graphdatascience.utils.util_proc_runner import UtilProcRunner
2120

2221

2322
class GraphDataScience(DirectEndpoints, UncallableNamespace):

graphdatascience/model/kge_runner.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,41 @@ def predict(
155155

156156
return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id)
157157

158+
@client_only_endpoint("gds.kge.model")
159+
def score_triplets(
160+
self,
161+
model_name: str,
162+
triplets: list[tuple[int, str, int]],
163+
mlflow_experiment_name: Optional[str] = None,
164+
) -> DataFrame:
165+
166+
algo_config = {
167+
"triplets": triplets,
168+
}
169+
170+
config = {
171+
"user_name": "DUMMY_USER",
172+
"task": "KGE_SCORE_TRIPLETS_PYG",
173+
"task_config": {
174+
"modelname": model_name,
175+
"task_config": algo_config,
176+
},
177+
"graph_arrow_uri": self._arrow_uri,
178+
}
179+
if self._encrypted_db_password is not None:
180+
config["encrypted_db_password"] = self._encrypted_db_password
181+
182+
if mlflow_experiment_name is not None:
183+
config["task_config"]["mlflow"] = {
184+
"config": {"tracking_uri": self._compute_cluster_mlflow_uri, "experiment_name": mlflow_experiment_name}
185+
}
186+
187+
job_id = self._start_job(config)
188+
189+
self._wait_for_job(job_id)
190+
191+
return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id)
192+
158193
def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataFrame:
159194
res = requests.get(
160195
f"{self._compute_cluster_web_uri}/internal/fetch-result",
@@ -172,11 +207,10 @@ def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataF
172207

173208
def _start_job(self, config: Dict[str, Any]) -> str:
174209
url = f"{self._compute_cluster_web_uri}/api/machine-learning/start"
175-
print(url)
176210
res = requests.post(url, json=config)
177211
res.raise_for_status()
178212
job_id = res.json()["job_id"]
179-
logging.info(f"Job with ID '{job_id}' started")
213+
logging.info(f"Job '{config['task']}' with ID '{job_id}' started")
180214

181215
return job_id
182216

0 commit comments

Comments
 (0)