|
1 | | -import collections |
2 | 1 | import os |
| 2 | +import warnings |
| 3 | +from collections import defaultdict |
3 | 4 |
|
| 5 | +from graphdatascience import GraphDataScience |
4 | 6 | from neo4j.exceptions import ClientError |
5 | 7 | from tqdm import tqdm |
6 | 8 |
|
7 | | -from graphdatascience import GraphDataScience |
| 9 | +warnings.filterwarnings("ignore", category=DeprecationWarning) |
8 | 10 |
|
9 | | -NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") |
10 | | -NEO4J_AUTH = None |
11 | | -NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j") |
12 | | -if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"): |
13 | | - NEO4J_AUTH = ( |
14 | | - os.environ.get("NEO4J_USER"), |
15 | | - os.environ.get("NEO4J_PASSWORD"), |
16 | | - ) |
17 | | -gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True) |
18 | 11 |
|
| 12 | +def setup_connection(): |
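| | + # Read connection settings from the environment; auth stays None for local instances with authentication disabled. |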
| 13 | + NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687") |
| 14 | + NEO4J_AUTH = None |
| 15 | + NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j") |
| 16 | + if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"): |
| 17 | + NEO4J_AUTH = ( |
| 18 | + os.environ.get("NEO4J_USER"), |
| 19 | + os.environ.get("NEO4J_PASSWORD"), |
| 20 | + ) |
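| | + # arrow=True enables Apache Arrow for faster data transfer between client and server. |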
| 21 | + gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True) |
19 | 22 |
|
20 | | -try: |
21 | | - _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE") |
22 | | -except ClientError: |
23 | | - print("CONSTRAINT entity_id already exists") |
| 23 | + return gds |
| 24 | + |
| 25 | + |
| 26 | +def create_constraint(gds): |
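| | + # A uniqueness constraint on Entity.id also creates a backing index, speeding up node lookups during import. |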
| 27 | + try: |
| 28 | + _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE") |
| 29 | + except ClientError as e: |
| 30 | + print(f"Constraint not created (it likely already exists): {e.message}") |
24 | 31 |
|
25 | | -import os |
26 | | -import zipfile |
27 | | -from collections import defaultdict |
28 | 32 |
|
29 | | -from ogb.utils.url import download_url |
| 33 | +def download_data(raw_file_names): |
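| | + # Download the FB15k-237 zip and extract only the three split files; returns the extraction directory. |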
| 35 | + import zipfile |
30 | 36 |
|
31 | | -url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip" |
32 | | -raw_dir = "./data_from_zip" |
33 | | -download_url(f"{url}", raw_dir) |
| 37 | + from ogb.utils.url import download_url |
34 | 38 |
|
35 | | -raw_file_names = ["train.txt", "valid.txt", "test.txt"] |
36 | | -with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref: |
37 | | - for filename in raw_file_names: |
38 | | - zip_ref.extract(f"Release/{filename}", path=raw_dir) |
39 | | -data_dir = raw_dir + "/" + "Release" |
| 39 | + url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip" |
| 40 | + raw_dir = "./data_from_zip" |
| 41 | + download_url(url, raw_dir) |
40 | 42 |
|
41 | | -rel_types = { |
42 | | - "train.txt": "TRAIN", |
43 | | - "valid.txt": "VALID", |
44 | | - "test.txt": "TEST", |
45 | | -} |
46 | | -rel_id_to_text_dict = {} |
47 | | -rel_type_dict = collections.defaultdict(list) |
48 | | -rel_dict = {} |
| 43 | + with zipfile.ZipFile(os.path.join(raw_dir, os.path.basename(url)), "r") as zip_ref: |
| 44 | + for filename in raw_file_names: |
| 45 | + zip_ref.extract(f"Release/{filename}", path=raw_dir) |
| 46 | + data_dir = os.path.join(raw_dir, "Release") |
| 47 | + return data_dir |
49 | 48 |
|
50 | 49 |
|
51 | 50 | def read_data(): |
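| | + # Map each split file to the relationship type used to tag its edges in the database. |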
| 51 | + rel_types = { |
| 52 | + "train.txt": "TRAIN", |
| 53 | + "valid.txt": "VALID", |
| 54 | + "test.txt": "TEST", |
| 55 | + } |
| 56 | + raw_file_names = ["train.txt", "valid.txt", "test.txt"] |
| 57 | + |
| 58 | + data_dir = download_data(raw_file_names) |
| 59 | + |
| 60 | + rel_id_to_text_dict = {} |
| 61 | + rel_dict = {} |
52 | 62 | node_id_set = {} |
53 | 63 | dataset = defaultdict(lambda: defaultdict(list)) |
54 | 64 | for file_name in raw_file_names: |
@@ -90,15 +100,16 @@ def read_data(): |
90 | 100 | return dataset |
91 | 101 |
|
92 | 102 |
|
93 | | -dataset = read_data() |
94 | | - |
95 | | - |
96 | | -def put_data_in_db(dataset): |
| 103 | +def put_data_in_db(gds): |
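| | + # Idempotent load: if the database already contains nodes, assume the data is in place and skip the import. |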
97 | 104 | res = gds.run_cypher("MATCH (m) RETURN count(m) as num_nodes") |
98 | | - if res['num_nodes'].values[0] > 0: |
99 | | - print("Data already in db, number of nodes: ", res['num_nodes'].values[0]) |
| 105 | + if res["num_nodes"].values[0] > 0: |
| 106 | + print(f"Data already in db, number of nodes: {res['num_nodes'].values[0]}") |
100 | 107 | return |
101 | | - pbar = tqdm(desc='Putting data in db', total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]])) |
| 108 | + dataset = read_data() |
| 109 | + pbar = tqdm( |
| 110 | + desc="Putting data in db", |
| 111 | + total=sum([len(dataset[rel_split][rel_type]) for rel_split in dataset for rel_type in dataset[rel_split]]), |
| 112 | + ) |
102 | 113 | for rel_split in dataset: |
103 | 114 | for rel_type in dataset[rel_split]: |
104 | 115 | edges = dataset[rel_split][rel_type] |
@@ -127,238 +138,50 @@ def put_data_in_db(dataset): |
127 | 138 | print(f"Number of relationships of type {rel_split} in db: ", res.numberOfRelationships) |
128 | 139 |
|
129 | 140 |
|
130 | | -put_data_in_db(dataset) |
131 | | - |
132 | | -ALL_RELS = dataset["TRAIN"].keys() |
133 | | -gds.graph.drop("trainGraph", failIfMissing=False) |
134 | | -G_train, result = gds.graph.cypher.project( |
| 141 | +def project_train_graph(gds): |
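| | + # Discover relationship types from the database instead of recomputing them from the raw dataset. |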
| 142 | + all_rels = gds.run_cypher( |
| 143 | + """ |
| 144 | + CALL db.relationshipTypes() YIELD relationshipType |
135 | 145 | """ |
136 | | - MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:""" |
137 | | - + "|".join(ALL_RELS) |
138 | | - + """]-(n) |
139 | | - RETURN gds.graph.project($graph_name, n, m, { |
140 | | - sourceNodeLabels: $label, |
141 | | - targetNodeLabels: $label |
142 | | - }) |
143 | | - """, # Cypher query |
144 | | - database="neo4j", # Target database |
145 | | - graph_name="trainGraph", # Query parameter |
146 | | - label="Entity", # Query parameter |
147 | | -) |
| 146 | + ) |
| 147 | + all_rels = all_rels["relationshipType"].to_list() |
| 148 | + all_rels = [rel for rel in all_rels if rel.startswith("REL_")] |
| 149 | + gds.graph.drop("trainGraph", failIfMissing=False) |
| 150 | + |
| 151 | + G_train, result = gds.graph.project("trainGraph", ["Entity"], all_rels) |
| 152 | + |
| 153 | + return G_train |
148 | 154 |
|
149 | 155 |
|
150 | 156 | def inspect_graph(G): |
151 | 157 | func_names = [ |
152 | 158 | "name", |
153 | | - # "database", |
154 | 159 | "node_count", |
155 | 160 | "relationship_count", |
156 | 161 | "node_labels", |
157 | 162 | "relationship_types", |
158 | | - # "degree_distribution", "density", "size_in_bytes", "memory_usage", "exists", "configuration", "creation_time", "modification_time", |
159 | 163 | ] |
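| | + # Look up each accessor on the Graph object by name and print its value. |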
160 | 164 | for func_name in func_names: |
161 | 165 | print(f"==={func_name}===: {getattr(G, func_name)()}") |
162 | 166 |
|
163 | 167 |
|
164 | | -inspect_graph(G_train) |
165 | | - |
166 | | -gds.set_compute_cluster_ip("localhost") |
167 | | - |
168 | | -kkge = gds.kge |
169 | | -kmodel = gds.kge.model |
170 | | - |
171 | | -print(gds.debug.arrow()) |
172 | | - |
173 | | -gds.kge.model.train( |
174 | | - G_train, |
175 | | - scoring_function="DistMult", |
176 | | - num_epochs=1, |
177 | | - embedding_dimension=10, |
178 | | -) |
179 | | - |
180 | | -print('Finished training') |
181 | | -# |
182 | | -# node_projection = {"Entity": {"properties": "id"}} |
183 | | -# relationship_projection = [ |
184 | | -# {"TRAIN": {"orientation": "NATURAL", "properties": "rel_id"}}, |
185 | | -# {"TEST": {"orientation": "NATURAL", "properties": "rel_id"}}, |
186 | | -# {"VALID": {"orientation": "NATURAL", "properties": "rel_id"}}, |
187 | | -# ] |
188 | | -# |
189 | | -# ttv_G, result = gds.graph.project( |
190 | | -# "fb15k-graph-ttv", |
191 | | -# node_projection, |
192 | | -# relationship_projection, |
193 | | -# ) |
194 | | -# |
195 | | -# node_properties = gds.graph.nodeProperties.stream( |
196 | | -# ttv_G, |
197 | | -# ["id"], |
198 | | -# separate_property_columns=True, |
199 | | -# ) |
200 | | -# |
201 | | -# nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id)) |
202 | | -# id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId)) |
203 | | -# |
204 | | -# def create_data_from_graph(relationship_type): |
205 | | -# rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type) |
206 | | -# topology = [ |
207 | | -# rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]), |
208 | | -# rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]), |
209 | | -# ] |
210 | | -# edge_index = torch.tensor(topology, dtype=torch.long) |
211 | | -# edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long) |
212 | | -# data = Data(edge_index=edge_index, edge_type=edge_type) |
213 | | -# data.num_nodes = len(nodeId_to_id) |
214 | | -# display(data) |
215 | | -# return data |
216 | | -# |
217 | | -# |
218 | | -# train_tensor_data = create_data_from_graph("TRAIN") |
219 | | -# test_tensor_data = create_data_from_graph("TEST") |
220 | | -# val_tensor_data = create_data_from_graph("VALID") |
221 | | -# |
222 | | -# gds.graph.drop(ttv_G) |
223 | | -# |
224 | | -# def train_model_with_pyg(): |
225 | | -# device = "cuda" if torch.cuda.is_available() else "cpu" |
226 | | -# |
227 | | -# model = TransE( |
228 | | -# num_nodes=train_tensor_data.num_nodes, |
229 | | -# num_relations=train_tensor_data.num_edge_types, |
230 | | -# hidden_channels=50, |
231 | | -# ).to(device) |
232 | | -# |
233 | | -# loader = model.loader( |
234 | | -# head_index=train_tensor_data.edge_index[0], |
235 | | -# rel_type=train_tensor_data.edge_type, |
236 | | -# tail_index=train_tensor_data.edge_index[1], |
237 | | -# batch_size=1000, |
238 | | -# shuffle=True, |
239 | | -# ) |
240 | | -# |
241 | | -# optimizer = optim.Adam(model.parameters(), lr=0.01) |
242 | | -# |
243 | | -# def train(): |
244 | | -# model.train() |
245 | | -# total_loss = total_examples = 0 |
246 | | -# for head_index, rel_type, tail_index in loader: |
247 | | -# optimizer.zero_grad() |
248 | | -# loss = model.loss(head_index, rel_type, tail_index) |
249 | | -# loss.backward() |
250 | | -# optimizer.step() |
251 | | -# total_loss += float(loss) * head_index.numel() |
252 | | -# total_examples += head_index.numel() |
253 | | -# return total_loss / total_examples |
254 | | -# |
255 | | -# @torch.no_grad() |
256 | | -# def test(data): |
257 | | -# model.eval() |
258 | | -# return model.test( |
259 | | -# head_index=data.edge_index[0], |
260 | | -# rel_type=data.edge_type, |
261 | | -# tail_index=data.edge_index[1], |
262 | | -# batch_size=1000, |
263 | | -# k=10, |
264 | | -# ) |
265 | | -# |
266 | | -# # Consider increasing the number of epochs |
267 | | -# epoch_count = 5 |
268 | | -# for epoch in range(1, epoch_count): |
269 | | -# loss = train() |
270 | | -# print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}") |
271 | | -# if epoch % 75 == 0: |
272 | | -# rank, hits = test(val_tensor_data) |
273 | | -# print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, " f"Val Hits@10: {hits:.4f}") |
274 | | -# |
275 | | -# torch.save(model, f"./model_{epoch_count}.pt") |
276 | | -# |
277 | | -# mean_rank, mrr, hits_at_k = test(test_tensor_data) |
278 | | -# print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}") |
279 | | -# |
280 | | -# return model |
281 | | -# |
282 | | -# model = train_model_with_pyg() |
283 | | -# # The model can be loaded if it was trained before |
284 | | -# # model = torch.load("./model_501.pt") |
285 | | -# |
286 | | -# for i in tqdm(range(len(nodeId_to_id))): |
287 | | -# gds.run_cypher( |
288 | | -# "MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING", |
289 | | -# params={"i": i, "EMBEDDING": model.node_emb.weight[i].tolist()}, |
290 | | -# ) |
291 | | -# |
292 | | -# relationship_to_predict = "/film/film/genre" |
293 | | -# rel_id_to_predict = rel_dict[relationship_to_predict] |
294 | | -# rel_label_to_predict = f"REL_{rel_id_to_predict}" |
295 | | -# |
296 | | -# G_test, result = gds.graph.project( |
297 | | -# "graph_to_predict_", |
298 | | -# {"Entity": {"properties": ["id", "emb"]}}, |
299 | | -# rel_label_to_predict, |
300 | | -# ) |
301 | | -# |
302 | | -# |
303 | | -# def print_graph_info(G): |
304 | | -# print(f"Graph '{G.name()}' node count: {G.node_count()}") |
305 | | -# print(f"Graph '{G.name()}' node labels: {G.node_labels()}") |
306 | | -# print(f"Graph '{G.name()}' relationship types: {G.relationship_types()}") |
307 | | -# print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}") |
308 | | -# |
309 | | -# |
310 | | -# print_graph_info(G_test) |
311 | | -# |
312 | | -# target_emb = model.node_emb.weight[rel_id_to_predict].tolist() |
313 | | -# transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb}) |
314 | | -# |
315 | | -# source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"] |
316 | | -# source_ids_df = gds.run_cypher( |
317 | | -# "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId", |
318 | | -# params={"node_text_list": source_node_list}, |
319 | | -# ) |
320 | | -# |
321 | | -# result = transe_model.predict_stream( |
322 | | -# source_node_filter=source_ids_df.nodeId, |
323 | | -# target_node_filter="Entity", |
324 | | -# relationship_type=rel_label_to_predict, |
325 | | -# top_k=3, |
326 | | -# concurrency=4, |
327 | | -# ) |
328 | | -# print(result) |
329 | | -# |
330 | | -# ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId])) |
331 | | -# |
332 | | -# ids_to_text = gds.run_cypher( |
333 | | -# "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id", |
334 | | -# params={"ids": ids_in_result}, |
335 | | -# ) |
336 | | -# |
337 | | -# nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag)) |
338 | | -# nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id)) |
339 | | -# |
340 | | -# result.insert(1, "sourceTag", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x])) |
341 | | -# result.insert(2, "sourceId", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x])) |
342 | | -# result.insert(4, "targetTag", result.targetNodeId.map(lambda x: nodeId_to_text_res[x])) |
343 | | -# result.insert(5, "targetId", result.targetNodeId.map(lambda x: nodeId_to_id_res[x])) |
344 | | -# |
345 | | -# print(result) |
346 | | -# |
347 | | -# write_relationship_type = "PREDICTED_" + rel_label_to_predict |
348 | | -# result_write = transe_model.predict_write( |
349 | | -# source_node_filter=source_ids_df.nodeId, |
350 | | -# target_node_filter="Entity", |
351 | | -# relationship_type=rel_label_to_predict, |
352 | | -# write_relationship_type=write_relationship_type, |
353 | | -# write_property="transe_score", |
354 | | -# top_k=3, |
355 | | -# concurrency=4, |
356 | | -# ) |
357 | | -# |
358 | | -# gds.run_cypher( |
359 | | -# "MATCH (n)-[r:" |
360 | | -# + write_relationship_type |
361 | | -# + "]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score" |
362 | | -# ) |
363 | | -# |
364 | | -# gds.graph.drop(G_test) |
| 168 | +if __name__ == "__main__": |
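| | + # End-to-end pipeline: connect, load FB15k-237, project the training graph, and train a KGE model. |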
| 169 | + gds = setup_connection() |
| 170 | + create_constraint(gds) |
| 171 | + put_data_in_db(gds) |
| 172 | + G_train = project_train_graph(gds) |
| 173 | + inspect_graph(G_train) |
| 174 | + |
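| | + # Point the client at the machine running the GDS compute cluster for KGE training (assumed to be local here). |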
| 175 | + gds.set_compute_cluster_ip("localhost") |
| 176 | + |
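| | + # Print Arrow diagnostics to verify the Arrow connection before training. |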
| 177 | + print(gds.debug.arrow()) |
| 178 | + |
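| | + # Minimal smoke-test settings: one epoch and a small embedding dimension; epochs_per_checkpoint=0 is assumed to disable checkpointing. |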
| 179 | + gds.kge.model.train( |
| 180 | + G_train, |
| 181 | + scoring_function="DistMult", |
| 182 | + num_epochs=1, |
| 183 | + embedding_dimension=10, |
| 184 | + epochs_per_checkpoint=0, |
| 185 | + ) |
| 186 | + |
| 187 | + print("Finished training") |