@@ -117,8 +117,6 @@ def put_data_in_db(gds):
117117 for rel_type in dataset [rel_split ]:
118118 edges = dataset [rel_split ][rel_type ]
119119
120- # MERGE (n)-[:{rel_type} {{text:l.rel_text}}]->(m)
121- # MERGE (n)-[:{rel_split}]->(m)
122120 gds .run_cypher (
123121 f"""
124122 UNWIND $ll as l
@@ -156,33 +154,6 @@ def project_train_graph(gds):
156154 return G_train
157155
158156
159- def project_predict_graph (gds ):
160- all_rels = gds .run_cypher (
161- """
162- CALL db.relationshipTypes() YIELD relationshipType
163- """
164- )
165- all_rels = all_rels ["relationshipType" ].to_list ()
166- rel_spec = {}
167- for rel in all_rels :
168- if rel .startswith ("REL_" ):
169- rel_spec [rel ] = {"properties" : ["split" ]}
170-
171- gds .graph .drop ("fullGraph" , failIfMissing = False )
172- gds .graph .drop ("predictGraph" , failIfMissing = False )
173-
174- # {"REL": {"properties": ["relY"]}, "RELR": {"properties": ["relY"]}}
175- # print(rel_spec)
176-
177- G_full , result = gds .graph .project ("fullGraph" , ["Entity" ], all_rels )
178-
179- G_full , result = gds .graph .project ("fullGraph" , ["Entity" ], rel_spec )
180- # G_predict = gds.graph.filter('predictGraph', 'fullGraph', '*', 'r.split == 2')
181-
182- inspect_graph (G_full )
183- return G_full
184-
185-
186157def inspect_graph (G ):
187158 func_names = [
188159 "name" ,
@@ -200,47 +171,43 @@ def inspect_graph(G):
200171 create_constraint (gds )
201172 put_data_in_db (gds )
202173 G_train = project_train_graph (gds )
203- # G_predict = project_predict_graph(gds)
204- # inspect_graph(G_train)
205174
206175 gds .set_compute_cluster_ip ("localhost" )
207176
208177 print (gds .debug .arrow ())
209178
210179 model_name = "dummyModelName_" + str (time .time ())
211180
212- gds .kge .model .train (
181+ node_id_text = gds .find_node_id (["Entity" ], {"text" : "/m/016wzw" })
182+ node_id_2 = gds .find_node_id (["Entity" ], {"id" : 2 })
183+ node_id_3 = gds .find_node_id (["Entity" ], {"id" : 3 })
184+ node_id_0 = gds .find_node_id (["Entity" ], {"id" : 0 })
185+
186+ res = gds .kge .model .train (
213187 G_train ,
214188 model_name = model_name ,
215- scoring_function = "DistMult " ,
189+ scoring_function = "distmult " ,
216190 num_epochs = 1 ,
217191 embedding_dimension = 10 ,
218192 epochs_per_checkpoint = 0 ,
219193 )
194+ print (res ['metrics' ])
220195
221- gds .kge .model .predict (
222- G_train ,
196+ res = gds .kge .model .predict (
223197 model_name = model_name ,
224198 top_k = 10 ,
225- node_ids = [1 , 2 , 3 ],
199+ node_ids = [node_id_3 , node_id_2 , node_id_text ],
226200 rel_types = ["REL_1" , "REL_2" ],
227201 )
202+ print (res .to_string ())
228203
229- gds .kge .model .predict_tail (
230- G_train ,
231- model_name = model_name ,
232- top_k = 10 ,
233- node_ids = [gds .find_node_id (["Entity" ], {"text" : "/m/016wzw" }), gds .find_node_id (["Entity" ], {"id" : 2 })],
234- rel_types = ["REL_1" , "REL_2" ],
235- )
236-
237- gds .kge .model .score_triples (
238- G_train ,
204+ scores = gds .kge .model .score_triplets (
239205 model_name = model_name ,
240- triples = [
241- (gds . find_node_id ([ "Entity" ], { "text" : "/m/016wzw" }), " REL_1" , gds . find_node_id ([ "Entity" ], { "id" : 2 }) ),
242- (gds . find_node_id ([ "Entity" ], { "id" : 0 }), " REL_123" , gds . find_node_id ([ "Entity" ], { "id" : 3 }) ),
206+ triplets = [
207+ (node_id_2 , " REL_1" , node_id_text ),
208+ (node_id_0 , " REL_123" , node_id_3 ),
243209 ],
244210 )
211+ print (scores )
245212
246213 print ("Finished training" )
0 commit comments