|
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | 15 | "execution_count": null, |
16 | | - "id": "8d9719b198c3fe8e", |
| 16 | + "id": "9135277efcde2800", |
17 | 17 | "metadata": {}, |
18 | 18 | "outputs": [], |
19 | 19 | "source": [ |
|
29 | 29 | { |
30 | 30 | "cell_type": "code", |
31 | 31 | "execution_count": null, |
32 | | - "id": "d4d82474217c5ca2", |
| 32 | + "id": "1551fddc3a67fa5b", |
33 | 33 | "metadata": {}, |
34 | 34 | "outputs": [], |
35 | 35 | "source": [ |
|
39 | 39 | { |
40 | 40 | "cell_type": "code", |
41 | 41 | "execution_count": null, |
42 | | - "id": "c522b3dba2a0c1c9", |
| 42 | + "id": "2f05ee7fdb496f84", |
43 | 43 | "metadata": {}, |
44 | 44 | "outputs": [], |
45 | 45 | "source": [ |
|
57 | 57 | { |
58 | 58 | "cell_type": "code", |
59 | 59 | "execution_count": null, |
60 | | - "id": "532f7596", |
| 60 | + "id": "658c9f8369fff77e", |
61 | 61 | "metadata": {}, |
62 | 62 | "outputs": [], |
63 | 63 | "source": [ |
|
70 | 70 | { |
71 | 71 | "cell_type": "code", |
72 | 72 | "execution_count": null, |
73 | | - "id": "00757ac4", |
| 73 | + "id": "bdbf4f91da4b9934", |
74 | 74 | "metadata": {}, |
75 | 75 | "outputs": [], |
76 | 76 | "source": [ |
|
84 | 84 | { |
85 | 85 | "cell_type": "code", |
86 | 86 | "execution_count": null, |
87 | | - "id": "6c9a1c4d", |
| 87 | + "id": "485869468ad5ad2e", |
88 | 88 | "metadata": {}, |
89 | 89 | "outputs": [], |
90 | 90 | "source": [ |
|
142 | 142 | " f\"Number of relationships of type {rel_split}: \",\n", |
143 | 143 | " sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),\n", |
144 | 144 | " )\n", |
145 | | - " return dataset\n", |
| 145 | + " return dataset, node_map\n", |
146 | 146 | "\n", |
147 | 147 | "\n", |
148 | | - "dataset = read_data()" |
| 148 | + "dataset, node_map = read_data()" |
149 | 149 | ] |
150 | 150 | }, |
151 | 151 | { |
152 | 152 | "cell_type": "code", |
153 | 153 | "execution_count": null, |
154 | | - "id": "e1cb98e4", |
| 154 | + "id": "2032a4e1aed1bd5", |
155 | 155 | "metadata": {}, |
156 | 156 | "outputs": [], |
157 | 157 | "source": [ |
|
160 | 160 | " if res[\"num_nodes\"].values[0] > 0:\n", |
161 | 161 | " print(\"Data already in db, number of nodes: \", res[\"num_nodes\"].values[0])\n", |
162 | 162 | " return\n", |
163 | | - " dataset = read_data()\n", |
164 | 163 | " pbar = tqdm(\n", |
165 | 164 | " desc=\"Putting data in db\",\n", |
166 | 165 | " total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),\n", |
|
198 | 197 | { |
199 | 198 | "cell_type": "code", |
200 | 199 | "execution_count": null, |
201 | | - "id": "0fceb15b", |
| 200 | + "id": "5c4f1523a225fa3c", |
202 | 201 | "metadata": {}, |
203 | 202 | "outputs": [], |
204 | 203 | "source": [ |
|
232 | 231 | { |
233 | 232 | "cell_type": "code", |
234 | 233 | "execution_count": null, |
235 | | - "id": "b4e2825a", |
| 234 | + "id": "5d518e67375f6ab3", |
236 | 235 | "metadata": {}, |
237 | 236 | "outputs": [], |
238 | 237 | "source": [ |
|
261 | 260 | " rel_types=[\"REL_RELDIPLOMACY\", \"REL_RELNGO\"],\n", |
262 | 261 | ")\n", |
263 | 262 | "\n", |
264 | | - "print(predict_result.to_string())\n", |
265 | | - "#\n", |
266 | | - "# gds.kge.model.predict_tail(\n", |
267 | | - "# G_train,\n", |
268 | | - "# model_name=model_name,\n", |
269 | | - "# top_k=10,\n", |
270 | | - "# node_ids=[gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), gds.find_node_id([\"Entity\"], {\"id\": 2})],\n", |
271 | | - "# rel_types=[\"REL_1\", \"REL_2\"],\n", |
272 | | - "# )\n", |
273 | | - "#\n", |
274 | | - "# gds.kge.model.score_triples(\n", |
275 | | - "# G_train,\n", |
276 | | - "# model_name=model_name,\n", |
277 | | - "# triples=[\n", |
278 | | - "# (gds.find_node_id([\"Entity\"], {\"text\": \"/m/016wzw\"}), \"REL_1\", gds.find_node_id([\"Entity\"], {\"id\": 2})),\n", |
279 | | - "# (gds.find_node_id([\"Entity\"], {\"id\": 0}), \"REL_123\", gds.find_node_id([\"Entity\"], {\"id\": 3})),\n", |
280 | | - "# ],\n", |
281 | | - "# )" |
| 263 | + "print(predict_result.to_string())" |
282 | 264 | ] |
283 | 265 | }, |
284 | 266 | { |
285 | 267 | "cell_type": "code", |
286 | 268 | "execution_count": null, |
287 | | - "id": "786eda29280ed31f", |
| 269 | + "id": "83b75194c69259a2", |
288 | 270 | "metadata": {}, |
289 | 271 | "outputs": [], |
290 | 272 | "source": [ |
291 | | - "# Create the dictionary" |
| 273 | + "for index, row in predict_result.iterrows():\n", |
| 274 | + " h = row[\"head\"]\n", |
| 275 | + " r = row[\"rel\"]\n", |
| 276 | + " gds.run_cypher(\n", |
| 277 | + " f\"\"\"\n", |
| 278 | + " UNWIND $tt as t\n", |
| 279 | + " MATCH (a:Entity WHERE id(a) = {h})\n", |
| 280 | + " MATCH (b:Entity WHERE id(b) = t)\n", |
| 281 | + " MERGE (a)-[:NEW_REL_{r}]->(b)\n", |
| 282 | + " \"\"\",\n", |
| 283 | + " params={\"tt\": row[\"tail\"]},\n", |
| 284 | + " )" |
292 | 285 | ] |
293 | 286 | }, |
294 | 287 | { |
295 | 288 | "cell_type": "code", |
296 | 289 | "execution_count": null, |
297 | | - "id": "74c501f8fcb411eb", |
| 290 | + "id": "b4e2825a", |
298 | 291 | "metadata": {}, |
299 | 292 | "outputs": [], |
300 | | - "source": [] |
| 293 | + "source": [ |
| 294 | + "brazil_node = gds.find_node_id([\"Entity\"], {\"text\": \"brazil\"})\n", |
| 295 | + "uk_node = gds.find_node_id([\"Entity\"], {\"text\": \"uk\"})\n", |
| 296 | + "jordan_node = gds.find_node_id([\"Entity\"], {\"text\": \"jordan\"})\n", |
| 297 | + "\n", |
| 298 | + "triplets = [\n", |
| 299 | + " (brazil_node, \"REL_RELNGO\", uk_node),\n", |
| 300 | + " (brazil_node, \"REL_RELDIPLOMACY\", jordan_node),\n", |
| 301 | + "]\n", |
| 302 | + "\n", |
| 303 | + "scores = gds.kge.model.score_triplets(\n", |
| 304 | + " model_name=model_name,\n", |
| 305 | + " triplets=triplets,\n", |
| 306 | + ")\n", |
| 307 | + "\n", |
| 308 | + "print(scores)" |
| 309 | + ] |
301 | 310 | } |
302 | 311 | ], |
303 | 312 | "metadata": {}, |
|
0 commit comments