1- import json
21from typing import Any , List
32
43from ..error .illegal_attr_checker import IllegalAttrChecker
54from ..error .uncallable_namespace import UncallableNamespace
5+ import json
66
77
88class GNNNodeClassificationRunner (UncallableNamespace , IllegalAttrChecker ):
9- def train (
10- self ,
11- graph_name : str ,
12- model_name : str ,
13- feature_properties : List [str ],
14- target_property : str ,
15- target_node_label : str = None ,
16- node_labels : List [str ] = None ,
17- ) -> "Series[Any]" :
9+ def train (self , graph_name : str , model_name : str , feature_properties : List [str ], target_property : str ,
10+ target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
1811 configMap = {
1912 "featureProperties" : feature_properties ,
2013 "targetProperty" : target_property ,
@@ -27,18 +20,21 @@ def train(
2720 mlTrainingConfig = json .dumps (configMap )
2821 # TODO query available node labels
2922 node_labels = ["Paper" ] if not node_labels else node_labels
23+
24+ # use arrow direclty here
3025 self ._query_runner .run_query (
31- f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { node_properties } }})"
32- )
26+ "CALL gds.upload.graph($config)" ,
27+ params = {"config" :
28+ {"graph_name" : graph_name ,
29+ "mlTrainingConfig" : mlTrainingConfig ,
30+ "modelName" : model_name ,
31+ "nodeLabels" : node_labels ,
32+ "nodeProperties" : node_properties
33+ }
34+ })
35+
3336
34- def predict (
35- self ,
36- graph_name : str ,
37- model_name : str ,
38- feature_properties : List [str ],
39- target_node_label : str = None ,
40- node_labels : List [str ] = None ,
41- ) -> "Series[Any]" :
37+ def predict (self , graph_name : str , model_name : str , feature_properties : List [str ], target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
4238 configMap = {
4339 "featureProperties" : feature_properties ,
4440 "job_type" : "predict" ,
@@ -49,5 +45,4 @@ def predict(
4945 # TODO query available node labels
5046 node_labels = ["Paper" ] if not node_labels else node_labels
5147 self ._query_runner .run_query (
52- f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})"
53- )
48+ f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})" )
0 commit comments