Skip to content

Commit a709d22

Browse files
committed
Make field notebook work again
1 parent 961906d commit a709d22

File tree

4 files changed

+38
-36
lines changed

4 files changed

+38
-36
lines changed

examples/kge-distmult-nations.ipynb renamed to examples/kge-distmult-nations-field.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@
347347
"outputs": [],
348348
"source": [
349349
"for index, row in predict_result.iterrows():\n",
350-
" h = row[\"head\"]\n",
350+
" h = row[\"sourceNodeId\"]\n",
351351
" r = row[\"rel\"]\n",
352352
" gds.run_cypher(\n",
353353
" f\"\"\"\n",
@@ -356,7 +356,7 @@
356356
" MATCH (b:Entity WHERE id(b) = t)\n",
357357
" MERGE (a)-[:NEW_REL_{r}]->(b)\n",
358358
" \"\"\",\n",
359-
" params={\"tt\": row[\"tail\"]},\n",
359+
" params={\"tt\": row[\"targetNodeIdTopK\"]},\n",
360360
" )"
361361
]
362362
},

examples/kge-distmult-nations.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,11 @@ def inspect_graph(G):
195195
res = gds.kge.model.train(
196196
G_train,
197197
model_name=model_name,
198-
scoring_function="distmult",
199-
num_epochs=1,
200-
embedding_dimension=10,
198+
scoring_function="TransE",
199+
num_epochs=30,
200+
embedding_dimension=64,
201201
epochs_per_checkpoint=0,
202-
epochs_per_val=5,
202+
epochs_per_val=0,
203203
split_ratios={"TRAIN": 0.8, "VALID": 0.1, "TEST": 0.1},
204204
)
205205
print(res["metrics"])
@@ -218,7 +218,7 @@ def inspect_graph(G):
218218
print(predict_result.to_string())
219219

220220
for index, row in predict_result.iterrows():
221-
h = row["head"]
221+
h = row["sourceNodeId"]
222222
r = row["rel"]
223223
gds.run_cypher(
224224
f"""
@@ -227,7 +227,7 @@ def inspect_graph(G):
227227
MATCH (b:Entity WHERE id(b) = t)
228228
MERGE (a)-[:NEW_REL_{r}]->(b)
229229
""",
230-
params={"tt": row["tail"]},
230+
params={"tt": row["targetNodeIdTopK"]},
231231
)
232232

233233
brazil_node = gds.find_node_id(["Entity"], {"text": "brazil"})

graphdatascience/graph_data_science.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44
import sys
55
from typing import Any, Dict, Optional, Tuple, Type, Union
66

7-
import rsa
87
from neo4j import Driver
98
from pandas import DataFrame
109

10+
from graphdatascience.graph.graph_proc_runner import GraphProcRunner
11+
from graphdatascience.utils.util_proc_runner import UtilProcRunner
12+
1113
from .call_builder import IndirectCallBuilder
1214
from .endpoints import AlphaEndpoints, BetaEndpoints, DirectEndpoints
1315
from .error.uncallable_namespace import UncallableNamespace
@@ -16,8 +18,6 @@
1618
from .query_runner.neo4j_query_runner import Neo4jQueryRunner
1719
from .query_runner.query_runner import QueryRunner
1820
from .server_version.server_version import ServerVersion
19-
from graphdatascience.graph.graph_proc_runner import GraphProcRunner
20-
from graphdatascience.utils.util_proc_runner import UtilProcRunner
2121

2222

2323
class GraphDataScience(DirectEndpoints, UncallableNamespace):
@@ -53,11 +53,11 @@ def __init__(
5353
database: Optional[str], default None
5454
The Neo4j database to query against.
5555
arrow : Union[str, bool], default True
56-
Arrow connection information. This is either a bool or a string.
57-
If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server.
58-
If it is a bool,
59-
True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint,
60-
while False will make the client use Bolt for all operations.
56+
Arrow connection information. This is either a string or a bool.
57+
- If it is a string, it will be interpreted as a connection URL to a GDS Arrow Server.
58+
- If it is a bool:
59+
- True will make the client discover the connection URI to the GDS Arrow server via the Neo4j endpoint.
60+
- False will make the client use Bolt for all operations.
6161
arrow_disable_server_verification : bool, default True
6262
A flag that overrides other TLS settings and disables server verification for TLS connections.
6363
arrow_tls_root_certs : Optional[bytes], default None
@@ -91,6 +91,7 @@ def __init__(
9191
# pub_key = rsa.PublicKey.load_pkcs1(f.read())
9292
# self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
9393

94+
self._encrypted_db_password = None
9495
self._compute_cluster_ip = None
9596

9697
super().__init__(self._query_runner, "gds", self._server_version)

graphdatascience/model/kge_runner.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from typing import Any, Dict, Optional
66

7-
import pandas as pd
7+
import pyarrow
88
import requests
99
from pandas import DataFrame, Series
1010

@@ -32,12 +32,13 @@ def __init__(
3232
self._namespace = namespace
3333
self._server_version = server_version
3434
self._compute_cluster_web_uri = f"http://{compute_cluster_ip}:5005"
35+
self._compute_cluster_arrow_uri = f"grpc://{compute_cluster_ip}:8815"
3536
self._compute_cluster_mlflow_uri = f"http://{compute_cluster_ip}:8080"
3637
self._encrypted_db_password = encrypted_db_password
3738
self._arrow_uri = arrow_uri
3839

3940
@property
40-
def model(self):
41+
def model(self) -> "KgeRunner":
4142
return self
4243

4344
# @compatible_with("stream", min_inclusive=ServerVersion(2, 5, 0))
@@ -75,7 +76,7 @@ def train(
7576
mlflow_experiment_name: Optional[str] = None,
7677
) -> Series:
7778
if epochs_per_checkpoint is None:
78-
epochs_per_checkpoint = max(num_epochs / 10, 1)
79+
epochs_per_checkpoint = max(int(num_epochs / 10), 1)
7980
if loss_function_kwargs is None:
8081
loss_function_kwargs = dict(margin=1.0, adversarial_temperature=1.0, gamma=20.0)
8182
if lr_scheduler_kwargs is None:
@@ -92,7 +93,7 @@ def train(
9293
}
9394
print(algo_config)
9495

95-
graph_config = {"name": G.name()}
96+
graph_config = {"name": G.name(), "config_type": "GdsGraphConfig"}
9697

9798
config = {
9899
"user_name": "DUMMY_USER",
@@ -133,7 +134,6 @@ def predict(
133134
rel_types: list[str],
134135
mlflow_experiment_name: Optional[str] = None,
135136
) -> DataFrame:
136-
137137
algo_config = {
138138
"top_k": top_k,
139139
"node_ids": node_ids,
@@ -144,8 +144,10 @@ def predict(
144144
"user_name": "DUMMY_USER",
145145
"task": "KGE_PREDICT_PYG",
146146
"task_config": {
147+
"graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"},
147148
"modelname": model_name,
148149
"task_config": algo_config,
150+
"stream_rel_results": True,
149151
},
150152
"graph_arrow_uri": self._arrow_uri,
151153
}
@@ -162,7 +164,7 @@ def predict(
162164

163165
self._wait_for_job(job_id)
164166

165-
return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id)
167+
return self._stream_results(config, job_id)
166168

167169
@client_only_endpoint("gds.kge.model")
168170
def score_triplets(
@@ -171,7 +173,6 @@ def score_triplets(
171173
triplets: list[tuple[int, str, int]],
172174
mlflow_experiment_name: Optional[str] = None,
173175
) -> DataFrame:
174-
175176
algo_config = {
176177
"triplets": triplets,
177178
}
@@ -180,8 +181,10 @@ def score_triplets(
180181
"user_name": "DUMMY_USER",
181182
"task": "KGE_SCORE_TRIPLETS_PYG",
182183
"task_config": {
184+
"graph_config": {"config_type": "GdsGraphConfig", "name": "NOGRAPH"},
183185
"modelname": model_name,
184186
"task_config": algo_config,
187+
"stream_rel_results": True,
185188
},
186189
"graph_arrow_uri": self._arrow_uri,
187190
}
@@ -198,22 +201,20 @@ def score_triplets(
198201

199202
self._wait_for_job(job_id)
200203

201-
return self._stream_results(config["user_name"], config["task_config"]["modelname"], job_id)
204+
return self._stream_results(config, job_id)
202205

203-
def _stream_results(self, user_name: str, model_name: str, job_id: str) -> DataFrame:
204-
res = requests.get(
205-
f"{self._compute_cluster_web_uri}/internal/fetch-result",
206-
params={"user_name": user_name, "modelname": model_name, "job_id": job_id},
207-
)
208-
res.raise_for_status()
206+
def _stream_results(self, config: dict, job_id: str) -> DataFrame:
207+
client = pyarrow.flight.connect(self._compute_cluster_arrow_uri)
209208

210-
res_file_name = f"res_{job_id}.json"
211-
with open(res_file_name, mode="wb+") as f:
212-
f.write(res.content)
209+
if config["task_config"].get("stream_rel_results", False):
210+
upload_descriptor = pyarrow.flight.FlightDescriptor.for_path(f"{job_id}.relationships")
211+
else:
212+
raise ValueError("No results to fetch: need to set stream_rel_results or stream_graph_results to True")
213+
flight = client.get_flight_info(upload_descriptor)
214+
reader = client.do_get(flight.endpoints[0].ticket)
215+
read_table = reader.read_all()
213216

214-
df = pd.read_json(res_file_name, orient="records", lines=True)
215-
os.remove(res_file_name)
216-
return df
217+
return read_table.to_pandas()
217218

218219
def _get_metrics(self, user_name: str, model_name: str, job_id: str) -> DataFrame:
219220
res = requests.get(

0 commit comments

Comments
 (0)