
Commit 88dfca0

Fix problem with KgeRunner
1 parent d6f6743 commit 88dfca0


3 files changed: +365 -12 lines


examples/kge-distmult.py

Lines changed: 352 additions & 0 deletions
@@ -0,0 +1,352 @@
import collections
import os

from neo4j.exceptions import ClientError
from tqdm import tqdm

from graphdatascience import GraphDataScience

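# Neo4j/GDS connection settings, taken from environment variables where available.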
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_AUTH = None
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
    NEO4J_AUTH = (
        os.environ.get("NEO4J_USER"),
        os.environ.get("NEO4J_PASSWORD"),
    )
gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)

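# Create a uniqueness constraint on Entity ids (ignored if it already exists).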
try:
    _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")
except ClientError:
    print("CONSTRAINT entity_id already exists")

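# Download and extract the FB15K-237 dataset.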
import os
import zipfile
from collections import defaultdict

from ogb.utils.url import download_url

url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
raw_dir = "./data_from_zip"
download_url(f"{url}", raw_dir)

raw_file_names = ["train.txt", "valid.txt", "test.txt"]
with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref:
    for filename in raw_file_names:
        zip_ref.extract(f"Release/{filename}", path=raw_dir)
data_dir = raw_dir + "/" + "Release"

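# Map each split file to a relationship type and prepare lookup tables for relation ids.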
rel_types = {
    "train.txt": "TRAIN",
    "valid.txt": "VALID",
    "test.txt": "TEST",
}
rel_id_to_text_dict = {}
rel_type_dict = collections.defaultdict(list)
rel_dict = {}

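# Parse the triples from each split into a nested dict: split -> relation type -> edge list.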
def read_data():
    node_id_set = {}
    dataset = defaultdict(lambda: defaultdict(list))
    for file_name in raw_file_names:
        file_name_path = data_dir + "/" + file_name

        with open(file_name_path, "r") as f:
            data = [x.split("\t") for x in f.read().split("\n")[:-1]]

        for i, (src_text, rel_text, dst_text) in enumerate(data):
            if src_text not in node_id_set:
                node_id_set[src_text] = len(node_id_set)
            if dst_text not in node_id_set:
                node_id_set[dst_text] = len(node_id_set)
            if rel_text not in rel_dict:
                rel_dict[rel_text] = len(rel_dict)
                rel_id_to_text_dict[rel_dict[rel_text]] = rel_text

            source = node_id_set[src_text]
            target = node_id_set[dst_text]
            rel_type = "REL_" + str(rel_dict[rel_text])
            rel_split = rel_types[file_name]

            dataset[rel_split][rel_type].append(
                {
                    "source": source,
                    "source_text": src_text,
                    "target": target,
                    "target_text": dst_text,
                    # "rel_text": rel_text,
                }
            )

    print("Number of nodes: ", len(node_id_set))
    for rel_split in dataset:
        print(
            f"Number of relationships of type {rel_split}: ",
            sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),
        )
    return dataset


dataset = read_data()

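# Write the entities and relationships into Neo4j, then report per-split counts.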
def put_data_in_db(dataset):
    for rel_split in tqdm(dataset, desc="Relationship"):
        for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):
            edges = dataset[rel_split][rel_type]

            # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
            gds.run_cypher(
                f"""
                UNWIND $ll as l
                MERGE (n:Entity {{id:l.source, text:l.source_text}})
                MERGE (m:Entity {{id:l.target, text:l.target_text}})
                MERGE (n)-[:{rel_split}]->(m)
                MERGE (n)-[:{rel_type}]->(m)
                """,
                params={"ll": edges},
            )

    for rel_split in dataset:
        res = gds.run_cypher(
            f"""
            MATCH ()-[r:{rel_split}]->()
            RETURN COUNT(r) AS numberOfRelationships
            """
        )
        print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships)


# put_data_in_db(dataset)

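# Project an in-memory training graph over the TRAIN split and all REL_* relationship types.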
ALL_RELS = dataset["TRAIN"].keys()
gds.graph.drop("trainGraph", failIfMissing=False)
G_train, result = gds.graph.cypher.project(
    """
    MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:"""
    + "|".join(ALL_RELS)
    + """]-(n)
    RETURN gds.graph.project($graph_name, n, m, {
        sourceNodeLabels: $label,
        targetNodeLabels: $label
    })
    """,  # Cypher query
    database="neo4j",  # Target database
    graph_name="trainGraph",  # Query parameter
    label="Entity",  # Query parameter
)

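# Print a short summary of the projected graph.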
def inspect_graph(G):
    func_names = [
        "name",
        # "database",
        "node_count",
        "relationship_count",
        "node_labels",
        "relationship_types",
        # "degree_distribution", "density", "size_in_bytes", "memory_usage", "exists", "configuration", "creation_time", "modification_time",
    ]
    for func_name in func_names:
        print(f"==={func_name}===: {getattr(G, func_name)()}")


inspect_graph(G_train)

gds.set_compute_cluster_ip("localhost")

kkge = gds.kge

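# Train a DistMult knowledge graph embedding model on the projected training graph.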
gds.kge.model.train(
    G_train,
    scoring_function="distmult",
    num_epochs=10,
    embedding_dimension=100,
)
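# The commented-out block below is an alternative workflow: it trains a TransE model
# with PyTorch Geometric, writes the node embeddings back to Neo4j, and uses the GDS
# TransE model to predict and write new relationships.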
#
# node_projection = {"Entity": {"properties": "id"}}
# relationship_projection = [
#     {"TRAIN": {"orientation": "NATURAL", "properties": "rel_id"}},
#     {"TEST": {"orientation": "NATURAL", "properties": "rel_id"}},
#     {"VALID": {"orientation": "NATURAL", "properties": "rel_id"}},
# ]
#
# ttv_G, result = gds.graph.project(
#     "fb15k-graph-ttv",
#     node_projection,
#     relationship_projection,
# )
#
# node_properties = gds.graph.nodeProperties.stream(
#     ttv_G,
#     ["id"],
#     separate_property_columns=True,
# )
#
# nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id))
# id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId))
#
# def create_data_from_graph(relationship_type):
#     rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
#     topology = [
#         rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
#         rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),
#     ]
#     edge_index = torch.tensor(topology, dtype=torch.long)
#     edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)
#     data = Data(edge_index=edge_index, edge_type=edge_type)
#     data.num_nodes = len(nodeId_to_id)
#     display(data)
#     return data
#
#
# train_tensor_data = create_data_from_graph("TRAIN")
# test_tensor_data = create_data_from_graph("TEST")
# val_tensor_data = create_data_from_graph("VALID")
#
# gds.graph.drop(ttv_G)
#
# def train_model_with_pyg():
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#
#     model = TransE(
#         num_nodes=train_tensor_data.num_nodes,
#         num_relations=train_tensor_data.num_edge_types,
#         hidden_channels=50,
#     ).to(device)
#
#     loader = model.loader(
#         head_index=train_tensor_data.edge_index[0],
#         rel_type=train_tensor_data.edge_type,
#         tail_index=train_tensor_data.edge_index[1],
#         batch_size=1000,
#         shuffle=True,
#     )
#
#     optimizer = optim.Adam(model.parameters(), lr=0.01)
#
#     def train():
#         model.train()
#         total_loss = total_examples = 0
#         for head_index, rel_type, tail_index in loader:
#             optimizer.zero_grad()
#             loss = model.loss(head_index, rel_type, tail_index)
#             loss.backward()
#             optimizer.step()
#             total_loss += float(loss) * head_index.numel()
#             total_examples += head_index.numel()
#         return total_loss / total_examples
#
#     @torch.no_grad()
#     def test(data):
#         model.eval()
#         return model.test(
#             head_index=data.edge_index[0],
#             rel_type=data.edge_type,
#             tail_index=data.edge_index[1],
#             batch_size=1000,
#             k=10,
#         )
#
#     # Consider increasing the number of epochs
#     epoch_count = 5
#     for epoch in range(1, epoch_count):
#         loss = train()
#         print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
#         if epoch % 75 == 0:
#             rank, hits = test(val_tensor_data)
#             print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}")
#
#     torch.save(model, f"./model_{epoch_count}.pt")
#
#     mean_rank, mrr, hits_at_k = test(test_tensor_data)
#     print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
#
#     return model
#
# model = train_model_with_pyg()
# # The model can be loaded if it was trained before
# # model = torch.load("./model_501.pt")
#
# for i in tqdm(range(len(nodeId_to_id))):
#     gds.run_cypher(
#         "MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING",
#         params={"i": i, "EMBEDDING": model.node_emb.weight[i].tolist()},
#     )
#
# relationship_to_predict = "/film/film/genre"
# rel_id_to_predict = rel_dict[relationship_to_predict]
# rel_label_to_predict = f"REL_{rel_id_to_predict}"
#
# G_test, result = gds.graph.project(
#     "graph_to_predict_",
#     {"Entity": {"properties": ["id", "emb"]}},
#     rel_label_to_predict,
# )
#
#
# def print_graph_info(G):
#     print(f"Graph '{G.name()}' node count: {G.node_count()}")
#     print(f"Graph '{G.name()}' node labels: {G.node_labels()}")
#     print(f"Graph '{G.name()}' relationship types: {G.relationship_types()}")
#     print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}")
#
#
# print_graph_info(G_test)
#
# target_emb = model.node_emb.weight[rel_id_to_predict].tolist()
# transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb})
#
# source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
# source_ids_df = gds.run_cypher(
#     "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
#     params={"node_text_list": source_node_list},
# )
#
# result = transe_model.predict_stream(
#     source_node_filter=source_ids_df.nodeId,
#     target_node_filter="Entity",
#     relationship_type=rel_label_to_predict,
#     top_k=3,
#     concurrency=4,
# )
# print(result)
#
# ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId]))
#
# ids_to_text = gds.run_cypher(
#     "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
#     params={"ids": ids_in_result},
# )
#
# nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag))
# nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id))
#
# result.insert(1, "sourceTag", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x]))
# result.insert(2, "sourceId", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x]))
# result.insert(4, "targetTag", result.targetNodeId.map(lambda x: nodeId_to_text_res[x]))
# result.insert(5, "targetId", result.targetNodeId.map(lambda x: nodeId_to_id_res[x]))
#
# print(result)
#
# write_relationship_type = "PREDICTED_" + rel_label_to_predict
# result_write = transe_model.predict_write(
#     source_node_filter=source_ids_df.nodeId,
#     target_node_filter="Entity",
#     relationship_type=rel_label_to_predict,
#     write_relationship_type=write_relationship_type,
#     write_property="transe_score",
#     top_k=3,
#     concurrency=4,
# )
#
# gds.run_cypher(
#     "MATCH (n)-[r:"
#     + write_relationship_type
#     + "]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
# )
#
# gds.graph.drop(G_test)

graphdatascience/graph_data_science.py

Lines changed: 9 additions & 9 deletions
@@ -87,10 +87,10 @@ def __init__(
             None if arrow is True else arrow,
         )
 
-        if auth is not None:
-            with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
-                pub_key = rsa.PublicKey.load_pkcs1(f.read())
-            self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
+        # if auth is not None:
+        #     with open(self._path("graphdatascience.resources.field-testing", "pub.pem"), "rb") as f:
+        #         pub_key = rsa.PublicKey.load_pkcs1(f.read())
+        #     self._encrypted_db_password = rsa.encrypt(auth[1].encode(), pub_key).hex()
 
         self._compute_cluster_ip = None
 
@@ -143,20 +143,20 @@ def fastpath(self) -> FastPathRunner:
 
     @property
     def kge(self) -> KgeRunner:
-        print("!!!kge")
-        # if not isinstance(self._query_runner, ArrowQueryRunner):
-        #     raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
+        print("!kge")
+        if not isinstance(self._query_runner, ArrowQueryRunner):
+            raise ValueError("Running FastPath requires GDS with the Arrow server enabled")
         if self._compute_cluster_ip is None:
             raise ValueError(
                 "You must set a valid computer cluster ip with the method `set_compute_cluster_ip` to use this feature"
             )
         return KgeRunner(
             self._query_runner,
-            "gds.kge",
+            "gds.kge.model",
             self._server_version,
             self._compute_cluster_ip,
             self._encrypted_db_password,
-            self._query_runner.uri,
+            None,
         )
 
     def __getattr__(self, attr: str) -> IndirectCallBuilder:
