
Commit d97a138

Works, dummy encrypted db password
1 parent 6f28c9c commit d97a138

3 files changed: +99 -270 lines changed

examples/kge-distmult.py

Lines changed: 85 additions & 262 deletions
@@ -1,54 +1,64 @@
-import collections
 import os
+import warnings
+from collections import defaultdict
 
+from graphdatascience import GraphDataScience
 from neo4j.exceptions import ClientError
 from tqdm import tqdm
 
-from graphdatascience import GraphDataScience
+warnings.filterwarnings("ignore", category=DeprecationWarning)
 
-NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
-NEO4J_AUTH = None
-NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
-if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
-    NEO4J_AUTH = (
-        os.environ.get("NEO4J_USER"),
-        os.environ.get("NEO4J_PASSWORD"),
-    )
-gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)
 
+def setup_connection():
+    NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
+    NEO4J_AUTH = None
+    NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
+    if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
+        NEO4J_AUTH = (
+            os.environ.get("NEO4J_USER"),
+            os.environ.get("NEO4J_PASSWORD"),
+        )
+    gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)
 
-try:
-    _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")
-except ClientError:
-    print("CONSTRAINT entity_id already exists")
+    return gds
+
+
+def create_constraint(gds):
+    try:
+        _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")
+    except ClientError:
+        print("CONSTRAINT entity_id already exists")
 
-import os
-import zipfile
-from collections import defaultdict
 
-from ogb.utils.url import download_url
+def download_data(raw_file_names):
+    import os
+    import zipfile
 
-url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
-raw_dir = "./data_from_zip"
-download_url(f"{url}", raw_dir)
+    from ogb.utils.url import download_url
 
-raw_file_names = ["train.txt", "valid.txt", "test.txt"]
-with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref:
-    for filename in raw_file_names:
-        zip_ref.extract(f"Release/{filename}", path=raw_dir)
-data_dir = raw_dir + "/" + "Release"
+    url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
+    raw_dir = "./data_from_zip"
+    download_url(f"{url}", raw_dir)
 
-rel_types = {
-    "train.txt": "TRAIN",
-    "valid.txt": "VALID",
-    "test.txt": "TEST",
-}
-rel_id_to_text_dict = {}
-rel_type_dict = collections.defaultdict(list)
-rel_dict = {}
+    with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref:
+        for filename in raw_file_names:
+            zip_ref.extract(f"Release/{filename}", path=raw_dir)
+    data_dir = raw_dir + "/" + "Release"
+    return data_dir
 
 
 def read_data():
+    rel_types = {
+        "train.txt": "TRAIN",
+        "valid.txt": "VALID",
+        "test.txt": "TEST",
+    }
+    raw_file_names = ["train.txt", "valid.txt", "test.txt"]
+
+    data_dir = download_data(raw_file_names)
+
+    rel_id_to_text_dict = {}
+    rel_dict = {}
     node_id_set = {}
     dataset = defaultdict(lambda: defaultdict(list))
     for file_name in raw_file_names:
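
The hunk above folds all connection handling into setup_connection(), driven entirely by environment variables. A minimal sketch of exercising it against an authenticated server; the credential values below are placeholders, and gds.version() serves only as a cheap connectivity check:

    import os

    # Placeholder credentials; point these at your own server.
    os.environ["NEO4J_URI"] = "bolt://localhost:7687"
    os.environ["NEO4J_USER"] = "neo4j"
    os.environ["NEO4J_PASSWORD"] = "changeme"

    gds = setup_connection()
    print(gds.version())  # prints the installed GDS version if the connection works
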
@@ -90,15 +100,16 @@ def read_data():
     return dataset
 
 
-dataset = read_data()
-
-
-def put_data_in_db(dataset):
+def put_data_in_db(gds):
     res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes")
-    if res['num_nodes'].values[0] > 0:
-        print("Data already in db, number of nodes: ", res['num_nodes'].values[0])
+    if res["num_nodes"].values[0] > 0:
+        print("Data already in db, number of nodes: ", res["num_nodes"].values[0])
         return
-    pbar = tqdm(desc='Putting data in db', total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]))
+    dataset = read_data()
+    pbar = tqdm(
+        desc="Putting data in db",
+        total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]),
+    )
     for rel_split in dataset:
         for rel_type in dataset[rel_split]:
             edges = dataset[rel_split][rel_type]
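
put_data_in_db(gds) now reads the dataset lazily and skips ingestion when the database already holds nodes. A small follow-up check, assuming the TRAIN/VALID/TEST relationship splits that this script writes:

    # Count the loaded relationships per split to confirm ingestion.
    for split in ("TRAIN", "VALID", "TEST"):
        res = gds.run_cypher(f"MATCH ()-[r:{split}]->() RETURN count(r) AS num_rels")
        print(split, res["num_rels"].values[0])
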
@@ -127,238 +138,50 @@ def put_data_in_db(dataset):
         print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships)
 
 
-put_data_in_db(dataset)
-
-ALL_RELS = dataset["TRAIN"].keys()
-gds.graph.drop("trainGraph", failIfMissing=False)
-G_train, result = gds.graph.cypher.project(
+def project_train_graph(gds):
+    all_rels = gds.run_cypher(
+        """
+        CALL db.relationshipTypes() YIELD relationshipType
     """
-    MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:"""
-    + "|".join(ALL_RELS)
-    + """]-(n)
-    RETURN gds.graph.project($graph_name, n, m, {
-        sourceNodeLabels: $label,
-        targetNodeLabels: $label
-    })
-    """,  # Cypher query
-    database="neo4j",  # Target database
-    graph_name="trainGraph",  # Query parameter
-    label="Entity",  # Query parameter
-)
+    )
+    all_rels = all_rels["relationshipType"].to_list()
+    all_rels = [rel for rel in all_rels if rel.startswith("REL_")]
+    gds.graph.drop("trainGraph", failIfMissing=False)
+
+    G_train, result = gds.graph.project("trainGraph", ["Entity"], all_rels)
+
+    return G_train
 
 
 def inspect_graph(G):
     func_names = [
         "name",
-        # "database",
         "node_count",
         "relationship_count",
         "node_labels",
         "relationship_types",
-        # "degree_distribution", "density", "size_in_bytes", "memory_usage", "exists", "configuration", "creation_time", "modification_time",
     ]
     for func_name in func_names:
         print(f"==={func_name}===: {getattr(G, func_name)()}")
 
 
-inspect_graph(G_train)
-
-gds.set_compute_cluster_ip("localhost")
-
-kkge = gds.kge
-kmodel = gds.kge.model
-
-print(gds.debug.arrow())
-
-gds.kge.model.train(
-    G_train,
-    scoring_function="DistMult",
-    num_epochs=1,
-    embedding_dimension=10,
-)
-
-print('Finished training')
-#
-# node_projection = {"Entity": {"properties": "id"}}
-# relationship_projection = [
-#     {"TRAIN": {"orientation": "NATURAL", "properties": "rel_id"}},
-#     {"TEST": {"orientation": "NATURAL", "properties": "rel_id"}},
-#     {"VALID": {"orientation": "NATURAL", "properties": "rel_id"}},
-# ]
-#
-# ttv_G, result = gds.graph.project(
-#     "fb15k-graph-ttv",
-#     node_projection,
-#     relationship_projection,
-# )
-#
-# node_properties = gds.graph.nodeProperties.stream(
-#     ttv_G,
-#     ["id"],
-#     separate_property_columns=True,
-# )
-#
-# nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id))
-# id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId))
-#
-# def create_data_from_graph(relationship_type):
-#     rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
-#     topology = [
-#         rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
-#         rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),
-#     ]
-#     edge_index = torch.tensor(topology, dtype=torch.long)
-#     edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)
-#     data = Data(edge_index=edge_index, edge_type=edge_type)
-#     data.num_nodes = len(nodeId_to_id)
-#     display(data)
-#     return data
-#
-#
-# train_tensor_data = create_data_from_graph("TRAIN")
-# test_tensor_data = create_data_from_graph("TEST")
-# val_tensor_data = create_data_from_graph("VALID")
-#
-# gds.graph.drop(ttv_G)
-#
-# def train_model_with_pyg():
-#     device = "cuda" if torch.cuda.is_available() else "cpu"
-#
-#     model = TransE(
-#         num_nodes=train_tensor_data.num_nodes,
-#         num_relations=train_tensor_data.num_edge_types,
-#         hidden_channels=50,
-#     ).to(device)
-#
-#     loader = model.loader(
-#         head_index=train_tensor_data.edge_index[0],
-#         rel_type=train_tensor_data.edge_type,
-#         tail_index=train_tensor_data.edge_index[1],
-#         batch_size=1000,
-#         shuffle=True,
-#     )
-#
-#     optimizer = optim.Adam(model.parameters(), lr=0.01)
-#
-#     def train():
-#         model.train()
-#         total_loss = total_examples = 0
-#         for head_index, rel_type, tail_index in loader:
-#             optimizer.zero_grad()
-#             loss = model.loss(head_index, rel_type, tail_index)
-#             loss.backward()
-#             optimizer.step()
-#             total_loss += float(loss) * head_index.numel()
-#             total_examples += head_index.numel()
-#         return total_loss / total_examples
-#
-#     @torch.no_grad()
-#     def test(data):
-#         model.eval()
-#         return model.test(
-#             head_index=data.edge_index[0],
-#             rel_type=data.edge_type,
-#             tail_index=data.edge_index[1],
-#             batch_size=1000,
-#             k=10,
-#         )
-#
-#     # Consider increasing the number of epochs
-#     epoch_count = 5
-#     for epoch in range(1, epoch_count):
-#         loss = train()
-#         print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
-#         if epoch % 75 == 0:
-#             rank, hits = test(val_tensor_data)
-#             print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}")
-#
-#     torch.save(model, f"./model_{epoch_count}.pt")
-#
-#     mean_rank, mrr, hits_at_k = test(test_tensor_data)
-#     print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
-#
-#     return model
-#
-# model = train_model_with_pyg()
-# # The model can be loaded if it was trained before
-# # model = torch.load("./model_501.pt")
-#
-# for i in tqdm(range(len(nodeId_to_id))):
-#     gds.run_cypher(
-#         "MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING",
-#         params={"i": i, "EMBEDDING": model.node_emb.weight[i].tolist()},
-#     )
-#
-# relationship_to_predict = "/film/film/genre"
-# rel_id_to_predict = rel_dict[relationship_to_predict]
-# rel_label_to_predict = f"REL_{rel_id_to_predict}"
-#
-# G_test, result = gds.graph.project(
-#     "graph_to_predict_",
-#     {"Entity": {"properties": ["id", "emb"]}},
-#     rel_label_to_predict,
-# )
-#
-#
-# def print_graph_info(G):
-#     print(f"Graph '{G.name()}' node count: {G.node_count()}")
-#     print(f"Graph '{G.name()}' node labels: {G.node_labels()}")
-#     print(f"Graph '{G.name()}' relationship types: {G.relationship_types()}")
-#     print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}")
-#
-#
-# print_graph_info(G_test)
-#
-# target_emb = model.node_emb.weight[rel_id_to_predict].tolist()
-# transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb})
-#
-# source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
-# source_ids_df = gds.run_cypher(
-#     "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
-#     params={"node_text_list": source_node_list},
-# )
-#
-# result = transe_model.predict_stream(
-#     source_node_filter=source_ids_df.nodeId,
-#     target_node_filter="Entity",
-#     relationship_type=rel_label_to_predict,
-#     top_k=3,
-#     concurrency=4,
-# )
-# print(result)
-#
-# ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId]))
-#
-# ids_to_text = gds.run_cypher(
-#     "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
-#     params={"ids": ids_in_result},
-# )
-#
-# nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag))
-# nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id))
-#
-# result.insert(1, "sourceTag", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x]))
-# result.insert(2, "sourceId", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x]))
-# result.insert(4, "targetTag", result.targetNodeId.map(lambda x: nodeId_to_text_res[x]))
-# result.insert(5, "targetId", result.targetNodeId.map(lambda x: nodeId_to_id_res[x]))
-#
-# print(result)
-#
-# write_relationship_type = "PREDICTED_" + rel_label_to_predict
-# result_write = transe_model.predict_write(
-#     source_node_filter=source_ids_df.nodeId,
-#     target_node_filter="Entity",
-#     relationship_type=rel_label_to_predict,
-#     write_relationship_type=write_relationship_type,
-#     write_property="transe_score",
-#     top_k=3,
-#     concurrency=4,
-# )
-#
-# gds.run_cypher(
-#     "MATCH (n)-[r:"
-#     + write_relationship_type
-#     + "]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
-# )
-#
-# gds.graph.drop(G_test)
+if __name__ == "__main__":
+    gds = setup_connection()
+    create_constraint(gds)
+    put_data_in_db(gds)
+    G_train = project_train_graph(gds)
+    inspect_graph(G_train)
+
+    gds.set_compute_cluster_ip("localhost")
+
+    print(gds.debug.arrow())
+
+    gds.kge.model.train(
+        G_train,
+        scoring_function="DistMult",
+        num_epochs=1,
+        embedding_dimension=10,
+        epochs_per_checkpoint=0,
+    )
+
+    print('Finished training')
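
Once training finishes, the projected graph still occupies memory on the server. A minimal cleanup sketch, assuming the trainGraph projection returned by project_train_graph(gds); both calls are standard graphdatascience client API:

    G_train.drop()  # release the in-memory "trainGraph" projection on the server
    gds.close()     # close the underlying Neo4j driver connection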
