# Example: load the FB15k-237 knowledge graph into Neo4j and train a
# knowledge-graph-embedding model (DistMult) with the GDS Python client.
import collections
import os

from neo4j.exceptions import ClientError
from tqdm import tqdm

from graphdatascience import GraphDataScience

NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_AUTH = None
NEO4J_DB = os.environ.get("NEO4J_DB", "neo4j")
if os.environ.get("NEO4J_USER") and os.environ.get("NEO4J_PASSWORD"):
    NEO4J_AUTH = (
        os.environ.get("NEO4J_USER"),
        os.environ.get("NEO4J_PASSWORD"),
    )
gds = GraphDataScience(NEO4J_URI, auth=NEO4J_AUTH, database=NEO4J_DB, arrow=True)
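# arrow=True opts in to Apache Arrow-based data transfer for bulk loading; this assumes the
# target GDS server has Arrow enabled.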


try:
    _ = gds.run_cypher("CREATE CONSTRAINT entity_id FOR (e:Entity) REQUIRE e.id IS UNIQUE")
except ClientError:
    print("CONSTRAINT entity_id already exists")
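# The uniqueness constraint on Entity.id keeps the MERGE statements during loading fast and
# guarantees each entity is stored once; on re-runs the existing constraint is simply reported.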

import zipfile
from collections import defaultdict

from ogb.utils.url import download_url

url = "https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip"
raw_dir = "./data_from_zip"
download_url(url, raw_dir)

raw_file_names = ["train.txt", "valid.txt", "test.txt"]
with zipfile.ZipFile(raw_dir + "/" + os.path.basename(url), "r") as zip_ref:
    for filename in raw_file_names:
        zip_ref.extract(f"Release/{filename}", path=raw_dir)
data_dir = raw_dir + "/" + "Release"
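# FB15k-237 ships as three tab-separated files (train/valid/test) with one
# (head, relation, tail) triple per line, which is what read_data() below parses.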

rel_types = {
    "train.txt": "TRAIN",
    "valid.txt": "VALID",
    "test.txt": "TEST",
}
rel_id_to_text_dict = {}
rel_type_dict = collections.defaultdict(list)
rel_dict = {}


def read_data():
    """Parse the FB15k-237 triple files into {split: {relation label: [edge dicts]}}."""
    node_id_set = {}
    dataset = defaultdict(lambda: defaultdict(list))
    for file_name in raw_file_names:
        file_name_path = data_dir + "/" + file_name

        with open(file_name_path, "r") as f:
            data = [x.split("\t") for x in f.read().split("\n")[:-1]]

        for src_text, rel_text, dst_text in data:
            if src_text not in node_id_set:
                node_id_set[src_text] = len(node_id_set)
            if dst_text not in node_id_set:
                node_id_set[dst_text] = len(node_id_set)
            if rel_text not in rel_dict:
                rel_dict[rel_text] = len(rel_dict)
                rel_id_to_text_dict[rel_dict[rel_text]] = rel_text

            source = node_id_set[src_text]
            target = node_id_set[dst_text]
            rel_type = "REL_" + str(rel_dict[rel_text])
            rel_split = rel_types[file_name]

            dataset[rel_split][rel_type].append(
                {
                    "source": source,
                    "source_text": src_text,
                    "target": target,
                    "target_text": dst_text,
                    # "rel_text": rel_text,
                }
            )

    print("Number of nodes: ", len(node_id_set))
    for rel_split in dataset:
        print(
            f"Number of relationships of type {rel_split}: ",
            sum(len(dataset[rel_split][rel_type]) for rel_type in dataset[rel_split]),
        )
    return dataset


dataset = read_data()
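# dataset now maps split name (TRAIN/VALID/TEST) -> relation label (REL_<id>) -> list of edge
# dicts holding integer ids plus the original text identifiers for source and target.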


def put_data_in_db(dataset):
    """Load the triples into Neo4j, tagging each edge with both its split and its relation label."""
    for rel_split in tqdm(dataset, desc="Relationship"):
        for rel_type in tqdm(dataset[rel_split], mininterval=1, leave=False):
            edges = dataset[rel_split][rel_type]

            # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
            gds.run_cypher(
                f"""
                UNWIND $ll as l
                MERGE (n:Entity {{id:l.source, text:l.source_text}})
                MERGE (m:Entity {{id:l.target, text:l.target_text}})
                MERGE (n)-[:{rel_split}]->(m)
                MERGE (n)-[:{rel_type}]->(m)
                """,
                params={"ll": edges},
            )

    for rel_split in dataset:
        res = gds.run_cypher(
            f"""
            MATCH ()-[r:{rel_split}]->()
            RETURN COUNT(r) AS numberOfRelationships
            """
        )
        print(f"Number of relationships of type {rel_split} in db: ", res["numberOfRelationships"][0])
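

# A batched variant of the loader, sketched here in case a single UNWIND over a whole split is
# too large for one transaction. Same schema assumptions as put_data_in_db above; the batch size
# is an arbitrary choice and the function is not called by default.
def put_data_in_db_batched(dataset, batch_size=10000):
    for rel_split in dataset:
        for rel_type, edges in dataset[rel_split].items():
            for start in range(0, len(edges), batch_size):
                batch = edges[start : start + batch_size]
                gds.run_cypher(
                    f"""
                    UNWIND $ll as l
                    MERGE (n:Entity {{id:l.source, text:l.source_text}})
                    MERGE (m:Entity {{id:l.target, text:l.target_text}})
                    MERGE (n)-[:{rel_split}]->(m)
                    MERGE (n)-[:{rel_type}]->(m)
                    """,
                    params={"ll": batch},
                )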

# Skip this call if the data set is already in the database.
put_data_in_db(dataset)

ALL_RELS = dataset["TRAIN"].keys()
gds.graph.drop("trainGraph", failIfMissing=False)
G_train, result = gds.graph.cypher.project(
    """
    MATCH (n:Entity)-[:TRAIN]->(m:Entity)<-[:"""
    + "|".join(ALL_RELS)
    + """]-(n)
    RETURN gds.graph.project($graph_name, n, m, {
        sourceNodeLabels: $label,
        targetNodeLabels: $label
    })
    """,  # Cypher query
    database=NEO4J_DB,  # Target database
    graph_name="trainGraph",  # Query parameter
    label="Entity",  # Query parameter
)
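# The projected "trainGraph" contains one relationship per training triple: the pattern keeps
# only Entity pairs joined by a TRAIN edge together with one of the REL_* types seen in training.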


def inspect_graph(G):
    func_names = [
        "name",
        # "database",
        "node_count",
        "relationship_count",
        "node_labels",
        "relationship_types",
        # "degree_distribution", "density", "size_in_bytes", "memory_usage", "exists", "configuration", "creation_time", "modification_time",
    ]
    for func_name in func_names:
        print(f"==={func_name}===: {getattr(G, func_name)()}")


inspect_graph(G_train)

gds.set_compute_cluster_ip("localhost")

gds.kge.model.train(
    G_train,
    scoring_function="distmult",
    num_epochs=10,
    embedding_dimension=100,
)
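# gds.kge.model.train() and set_compute_cluster_ip() are part of the client's experimental KGE
# support; this assumes a KGE-capable compute service is reachable at the configured address.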
#
# The commented-out section below is an alternative workflow: stream the triples into PyG,
# train a TransE model with torch_geometric, write the embeddings back to Neo4j and use the
# GDS TransE predictor. If re-enabled, it also needs these imports:
#
# import pandas as pd
# import torch
# from torch import optim
# from torch_geometric.data import Data
# from torch_geometric.nn.kge import TransE
#
# node_projection = {"Entity": {"properties": "id"}}
# relationship_projection = [
#     {"TRAIN": {"orientation": "NATURAL", "properties": "rel_id"}},
#     {"TEST": {"orientation": "NATURAL", "properties": "rel_id"}},
#     {"VALID": {"orientation": "NATURAL", "properties": "rel_id"}},
# ]
#
# ttv_G, result = gds.graph.project(
#     "fb15k-graph-ttv",
#     node_projection,
#     relationship_projection,
# )
#
# node_properties = gds.graph.nodeProperties.stream(
#     ttv_G,
#     ["id"],
#     separate_property_columns=True,
# )
#
# nodeId_to_id = dict(zip(node_properties.nodeId, node_properties.id))
# id_to_nodeId = dict(zip(node_properties.id, node_properties.nodeId))
#
# def create_data_from_graph(relationship_type):
#     rels_tmp = gds.graph.relationshipProperty.stream(ttv_G, "rel_id", relationship_type)
#     topology = [
#         rels_tmp.sourceNodeId.map(lambda x: nodeId_to_id[x]),
#         rels_tmp.targetNodeId.map(lambda x: nodeId_to_id[x]),
#     ]
#     edge_index = torch.tensor(topology, dtype=torch.long)
#     edge_type = torch.tensor(rels_tmp.propertyValue.astype(int), dtype=torch.long)
#     data = Data(edge_index=edge_index, edge_type=edge_type)
#     data.num_nodes = len(nodeId_to_id)
#     print(data)
#     return data
#
#
# train_tensor_data = create_data_from_graph("TRAIN")
# test_tensor_data = create_data_from_graph("TEST")
# val_tensor_data = create_data_from_graph("VALID")
#
# gds.graph.drop(ttv_G)
#
# def train_model_with_pyg():
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#
#     model = TransE(
#         num_nodes=train_tensor_data.num_nodes,
#         num_relations=train_tensor_data.num_edge_types,
#         hidden_channels=50,
#     ).to(device)
#
#     loader = model.loader(
#         head_index=train_tensor_data.edge_index[0],
#         rel_type=train_tensor_data.edge_type,
#         tail_index=train_tensor_data.edge_index[1],
#         batch_size=1000,
#         shuffle=True,
#     )
#
#     optimizer = optim.Adam(model.parameters(), lr=0.01)
#
#     def train():
#         model.train()
#         total_loss = total_examples = 0
#         for head_index, rel_type, tail_index in loader:
#             optimizer.zero_grad()
#             loss = model.loss(head_index, rel_type, tail_index)
#             loss.backward()
#             optimizer.step()
#             total_loss += float(loss) * head_index.numel()
#             total_examples += head_index.numel()
#         return total_loss / total_examples
#
#     @torch.no_grad()
#     def test(data):
#         model.eval()
#         return model.test(
#             head_index=data.edge_index[0],
#             rel_type=data.edge_type,
#             tail_index=data.edge_index[1],
#             batch_size=1000,
#             k=10,
#         )
#
#     # Consider increasing the number of epochs
#     epoch_count = 5
#     for epoch in range(1, epoch_count + 1):
#         loss = train()
#         print(f"Epoch: {epoch:03d}, Loss: {loss:.4f}")
#         if epoch % 75 == 0:
#             rank, mrr, hits = test(val_tensor_data)
#             print(f"Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}")
#
#     torch.save(model, f"./model_{epoch_count}.pt")
#
#     mean_rank, mrr, hits_at_k = test(test_tensor_data)
#     print(f"Test Mean Rank: {mean_rank:.2f}, Test Hits@10: {hits_at_k:.4f}, MRR: {mrr:.4f}")
#
#     return model
#
# model = train_model_with_pyg()
# # The model can be loaded if it was trained before
# # model = torch.load("./model_501.pt")
#
# for i in tqdm(range(len(nodeId_to_id))):
#     gds.run_cypher(
#         "MATCH (n:Entity {id: $i}) SET n.emb=$EMBEDDING",
#         params={"i": i, "EMBEDDING": model.node_emb.weight[i].tolist()},
#     )
#
# relationship_to_predict = "/film/film/genre"
# rel_id_to_predict = rel_dict[relationship_to_predict]
# rel_label_to_predict = f"REL_{rel_id_to_predict}"
#
# G_test, result = gds.graph.project(
#     "graph_to_predict_",
#     {"Entity": {"properties": ["id", "emb"]}},
#     rel_label_to_predict,
# )
#
#
# def print_graph_info(G):
#     print(f"Graph '{G.name()}' node count: {G.node_count()}")
#     print(f"Graph '{G.name()}' node labels: {G.node_labels()}")
#     print(f"Graph '{G.name()}' relationship types: {G.relationship_types()}")
#     print(f"Graph '{G.name()}' relationship count: {G.relationship_count()}")
#
#
# print_graph_info(G_test)
#
# # Use the learned relation embedding for the relation label we want to predict.
# target_emb = model.rel_emb.weight[rel_id_to_predict].tolist()
# transe_model = gds.model.transe.create(G_test, "emb", {rel_label_to_predict: target_emb})
#
# source_node_list = ["/m/07l450", "/m/0ds2l81", "/m/0jvt9"]
# source_ids_df = gds.run_cypher(
#     "UNWIND $node_text_list AS t MATCH (n:Entity) WHERE n.text=t RETURN id(n) as nodeId",
#     params={"node_text_list": source_node_list},
# )
#
# result = transe_model.predict_stream(
#     source_node_filter=source_ids_df.nodeId,
#     target_node_filter="Entity",
#     relationship_type=rel_label_to_predict,
#     top_k=3,
#     concurrency=4,
# )
# print(result)
#
# ids_in_result = pd.unique(pd.concat([result.sourceNodeId, result.targetNodeId]))
#
# ids_to_text = gds.run_cypher(
#     "UNWIND $ids AS id MATCH (n:Entity) WHERE id(n)=id RETURN id(n) AS nodeId, n.text AS tag, n.id AS id",
#     params={"ids": ids_in_result},
# )
#
# nodeId_to_text_res = dict(zip(ids_to_text.nodeId, ids_to_text.tag))
# nodeId_to_id_res = dict(zip(ids_to_text.nodeId, ids_to_text.id))
#
# result.insert(1, "sourceTag", result.sourceNodeId.map(lambda x: nodeId_to_text_res[x]))
# result.insert(2, "sourceId", result.sourceNodeId.map(lambda x: nodeId_to_id_res[x]))
# result.insert(4, "targetTag", result.targetNodeId.map(lambda x: nodeId_to_text_res[x]))
# result.insert(5, "targetId", result.targetNodeId.map(lambda x: nodeId_to_id_res[x]))
#
# print(result)
#
# write_relationship_type = "PREDICTED_" + rel_label_to_predict
# result_write = transe_model.predict_write(
#     source_node_filter=source_ids_df.nodeId,
#     target_node_filter="Entity",
#     relationship_type=rel_label_to_predict,
#     write_relationship_type=write_relationship_type,
#     write_property="transe_score",
#     top_k=3,
#     concurrency=4,
# )
#
# gds.run_cypher(
#     "MATCH (n)-[r:"
#     + write_relationship_type
#     + "]->(m) RETURN n.id AS sourceId, n.text AS sourceTag, m.id AS targetId, m.text AS targetTag, r.transe_score AS score"
# )
#
# gds.graph.drop(G_test)