Commit ce6e824: KGE distmult nations is working
1 parent 577520b

5 files changed: +369 -37 lines changed
examples/kge-distmult-nations.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
import os
import time
import warnings
from collections import defaultdict

from neo4j.exceptions import ClientError
from tqdm import tqdm

from graphdatascience import GraphDataScience

warnings.filterwarnings("ignore", category=DeprecationWarning)


def setup_connection():
    # Read connection settings from the environment; default to a local,
    # unauthenticated instance. Arrow is enabled for fast data transfer.
    NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
    NEO4J_AUTH = None
    NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
    if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
        NEO4J_AUTH = (
            os.environ.get("NEO4J_USER"),
            os.environ.get("NEO4J_PASSWORD"),
        )
    gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)

    return gds


def create_constraint(gds):
    # The uniqueness constraint on :Entity(id) is index-backed, which also
    # speeds up the MERGEs in put_data_in_db.
    try:
        _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")
    except ClientError:
        print("CONSTRAINT entity_id already exists")


def download_data(raw_file_names):
    # NOTE: unused in this script; read_data() loads the Nations dataset from
    # a local directory instead of this FB15K-237 download.
    import os
    import zipfile

    from ogb.utils.url import download_url

    url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
    raw_dir = "./data_from_zip"
    download_url(url, raw_dir)

    with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref:
        for filename in raw_file_names:
            zip_ref.extract(f"Release/{filename}", path=raw_dir)
    data_dir = raw_dir + "/" + "Release"
    return data_dir


def get_text_to_id_map(data_dir, text_to_id_filename):
    with open(data_dir + "/" + text_to_id_filename, "r") as f:
        data = [x.split("\t") for x in f.read().split("\n")[:-1]]
        text_to_id_map = {text: int(id) for text, id in data}
    return text_to_id_map


def read_data():
    rel_types = {
        "train.txt": "TRAIN",
        "valid.txt": "VALID",
        "test.txt": "TEST",
    }
    raw_file_names = ["train.txt", "valid.txt", "test.txt"]
    node_id_filename = "entity2id.txt"
    rel_id_filename = "relation2id.txt"

    # NOTE: hard-coded local copy of the Nations dataset.
    data_dir = "/Users/olgarazvenskaia/work/datasets/KGDatasets/Nations"
    node_map = get_text_to_id_map(data_dir, node_id_filename)
    rel_map = get_text_to_id_map(data_dir, rel_id_filename)
    dataset = defaultdict(lambda: defaultdict(list))

    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}

    for file_name in raw_file_names:
        file_name_path = data_dir + "/" + file_name

        with open(file_name_path, "r") as f:
            data = [x.split("\t") for x in f.read().split("\n")[:-1]]

        for src_text, rel_text, dst_text in data:
            source = node_map[src_text]
            target = node_map[dst_text]
            rel_type = "REL_" + rel_text.upper()
            rel_split = rel_types[file_name]

            dataset[rel_split][rel_type].append(
                {
                    "source": source,
                    "source_text": src_text,
                    "target": target,
                    "target_text": dst_text,
                    "rel_type": rel_type,
                    "rel_id": rel_map[rel_text],
                    "rel_split": rel_split,
                    "rel_split_id": rel_split_id[rel_split],
                }
            )

    print("Number of nodes: ", len(node_map))
    for rel_split in dataset:
        print(
            f"Number of relationships of type {rel_split}: ",
            sum([len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]]),
        )
    return dataset


def put_data_in_db(gds):
    # Skip loading if the database is already populated.
    res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
    if res["num_nodes"].values[0] > 0:
        print("Data already in db, number of nodes: ", res["num_nodes"].values[0])
        return
    dataset = read_data()
    pbar = tqdm(
        desc="Putting data in db",
        total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),
    )

    for rel_split in dataset:
        for rel_type in dataset[rel_split]:
            edges = dataset[rel_split][rel_type]

            gds.run_cypher(
                f"""
                UNWIND $ll as l
                MERGE (n:Entity {{id:l.source, text:l.source_text}})
                MERGE (m:Entity {{id:l.target, text:l.target_text}})
                MERGE (n)-[:{rel_type} {{split: l.rel_split_id, rel_id: l.rel_id}}]->(m)
                """,
                params={"ll": edges},
            )
            pbar.update(len(edges))
    pbar.close()

    # Relationships are stored under their REL_* type with a `split` property,
    # not under TRAIN/VALID/TEST types, so verify the counts via the property.
    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}
    for rel_split in dataset:
        res = gds.run_cypher(
            """
            MATCH ()-[r]->()
            WHERE r.split = $split
            RETURN COUNT(r) AS numberOfRelationships
            """,
            params={"split": rel_split_id[rel_split]},
        )
        print(f"Number of relationships of split {rel_split} in db: ", res["numberOfRelationships"].values[0])


def project_graphs(gds):
    all_rels = gds.run_cypher(
        """
        CALL db.relationshipTypes() YIELD relationshipType
        """
    )
    all_rels = all_rels["relationshipType"].to_list()
    all_rels = {rel: {"properties": "split"} for rel in all_rels if rel.startswith("REL_")}
    gds.graph.drop("fullGraph", failIfMissing=False)
    gds.graph.drop("trainGraph", failIfMissing=False)
    gds.graph.drop("validGraph", failIfMissing=False)
    gds.graph.drop("testGraph", failIfMissing=False)

    G_full, _ = gds.graph.project("fullGraph", ["Entity"], all_rels)
    inspect_graph(G_full)

    # GDS stores relationship properties as doubles, hence the 0.0/1.0/2.0 comparisons.
    G_train, _ = gds.graph.filter("trainGraph", G_full, "*", "r.split = 0.0")
    G_valid, _ = gds.graph.filter("validGraph", G_full, "*", "r.split = 1.0")
    G_test, _ = gds.graph.filter("testGraph", G_full, "*", "r.split = 2.0")

    inspect_graph(G_train)
    inspect_graph(G_valid)
    inspect_graph(G_test)

    gds.graph.drop("fullGraph", failIfMissing=False)

    return G_train, G_valid, G_test


def inspect_graph(G):
    func_names = [
        "name",
        "node_count",
        "relationship_count",
        "node_labels",
        "relationship_types",
    ]
    for func_name in func_names:
        print(f"==={func_name}===: {getattr(G, func_name)()}")
if __name__ == "__main__":
189+
gds = setup_connection()
190+
create_constraint(gds)
191+
put_data_in_db(gds)
192+
G_train, G_valid, G_test = project_graphs(gds)
193+
inspect_graph(G_train)
194+
inspect_graph(G_valid)
195+
inspect_graph(G_test)
196+
197+
gds.set_compute_cluster_ip("localhost")
198+
199+
print(gds.debug.arrow())
200+
201+
model_name = "dummyModelName_" + str(time.time())
202+
203+
gds.kge.model.train(
204+
G_train,
205+
model_name=model_name,
206+
scoring_function="DistMult",
207+
num_epochs=1,
208+
embedding_dimension=10,
209+
epochs_per_checkpoint=0,
210+
)
211+
212+
df = gds.kge.model.predict(
213+
G_train,
214+
model_name=model_name,
215+
top_k=10,
216+
node_ids=[
217+
gds.find_node_id(["Entity"], {"text": "brazil"}),
218+
gds.find_node_id(["Entity"], {"text": "uk"}),
219+
gds.find_node_id(["Entity"], {"text": "jordan"}),
220+
],
221+
rel_types=["REL_RELDIPLOMACY", "REL_RELNGO"],
222+
)
223+
224+
print(df)
225+
#
226+
# gds.kge.model.predict_tail(
227+
# G_train,
228+
# model_name=model_name,
229+
# top_k=10,
230+
# node_ids=[gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), gds.find_node_id(["Entity"], {"id": 2})],
231+
# rel_types=["REL_1", "REL_2"],
232+
# )
233+
#
234+
# gds.kge.model.score_triples(
235+
# G_train,
236+
# model_name=model_name,
237+
# triples=[
238+
# (gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), "REL_1", gds.find_node_id(["Entity"], {"id": 2})),
239+
# (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})),
240+
# ],
241+
# )
242+
243+
print("Finished training")

examples/kge-distmult.py

Lines changed: 64 additions & 5 deletions
@@ -1,11 +1,13 @@
 import os
+import time
 import warnings
 from collections import defaultdict
 
-from graphdatascience import GraphDataScience
 from neo4j.exceptions import ClientError
 from tqdm import tqdm
 
+from graphdatascience import GraphDataScience
+
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
 
@@ -110,18 +112,19 @@ def put_data_in_db(gds):
         desc="Putting data in db",
         total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),
     )
+    rel_split_id = {"TRAIN": 0, "VALID": 1, "TEST": 2}
     for rel_split in dataset:
         for rel_type in dataset[rel_split]:
             edges = dataset[rel_split][rel_type]
 
             # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
+            # MERGE (n)-[:{rel_split}]->(m)
             gds.run_cypher(
                 f"""
                 UNWIND $ll as l
                 MERGE (n:Entity {{id:l.source, text:l.source_text}})
                 MERGE (m:Entity {{id:l.target, text:l.target_text}})
-                MERGE (n)-[:{rel_split}]->(m)
-                MERGE (n)-[:{rel_type}]->(m)
+                MERGE (n)-[:{rel_type} {{split: {rel_split_id[rel_split]}}}]->(m)
                 """,
                 params={"ll": edges},
             )
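
The rewritten MERGE stores the train/valid/test split as an integer property on the typed relationship instead of creating a second, split-named relationship type. That is what enables property-based subgraph filtering downstream; a minimal sketch of the pattern, where REL_X is a placeholder type and GDS holds the property as a double, hence 0.0:

# Sketch only: project REL_X with its split property, then filter the training subgraph.
G_full, _ = gds.graph.project("fullGraph", ["Entity"], {"REL_X": {"properties": "split"}})
G_train, _ = gds.graph.filter("trainGraph", G_full, "*", "r.split = 0.0")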
@@ -153,6 +156,33 @@ def project_train_graph(gds):
     return G_train
 
 
+def project_predict_graph(gds):
+    all_rels = gds.run_cypher(
+        """
+        CALL db.relationshipTypes() YIELD relationshipType
+        """
+    )
+    all_rels = all_rels["relationshipType"].to_list()
+    rel_spec = {}
+    for rel in all_rels:
+        if rel.startswith("REL_"):
+            rel_spec[rel] = {"properties": ["split"]}
+
+    gds.graph.drop("fullGraph", failIfMissing=False)
+    gds.graph.drop("predictGraph", failIfMissing=False)
+
+    # {"REL": {"properties": ["relY"]}, "RELR": {"properties": ["relY"]}}
+    # print(rel_spec)
+
+    # Project once, with the split property on each REL_* type; projecting
+    # "fullGraph" a second time would fail because the graph already exists.
+    G_full, result = gds.graph.project("fullGraph", ["Entity"], rel_spec)
+    # G_predict = gds.graph.filter('predictGraph', 'fullGraph', '*', 'r.split == 2')
+
+    inspect_graph(G_full)
+    return G_full
+
+
 def inspect_graph(G):
     func_names = [
         "name",
@@ -170,18 +200,47 @@ def inspect_graph(G):
     create_constraint(gds)
     put_data_in_db(gds)
     G_train = project_train_graph(gds)
-    inspect_graph(G_train)
+    # G_predict = project_predict_graph(gds)
+    # inspect_graph(G_train)
 
     gds.set_compute_cluster_ip("localhost")
 
     print(gds.debug.arrow())
 
+    model_name = "dummyModelName_" + str(time.time())
+
     gds.kge.model.train(
         G_train,
+        model_name=model_name,
         scoring_function="DistMult",
         num_epochs=1,
         embedding_dimension=10,
         epochs_per_checkpoint=0,
     )
 
-    print('Finished training')
+    gds.kge.model.predict(
+        G_train,
+        model_name=model_name,
+        top_k=10,
+        node_ids=[1, 2, 3],
+        rel_types=["REL_1", "REL_2"],
+    )
+
+    gds.kge.model.predict_tail(
+        G_train,
+        model_name=model_name,
+        top_k=10,
+        node_ids=[gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), gds.find_node_id(["Entity"], {"id": 2})],
+        rel_types=["REL_1", "REL_2"],
+    )
+
+    gds.kge.model.score_triples(
+        G_train,
+        model_name=model_name,
+        triples=[
+            (gds.find_node_id(["Entity"], {"text": "/m/016wzw"}), "REL_1", gds.find_node_id(["Entity"], {"id": 2})),
+            (gds.find_node_id(["Entity"], {"id": 0}), "REL_123", gds.find_node_id(["Entity"], {"id": 3})),
+        ],
+    )
+
+    print("Finished training")
