@@ -11,11 +11,26 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str],
1111 configMap = {
1212 "featureProperties" : feature_properties ,
1313 "targetProperty" : target_property ,
14+ "job_type" : "train" ,
1415 }
1516 node_properties = feature_properties + [target_property ]
1617 if target_node_label :
1718 configMap ["targetNodeLabel" ] = target_node_label
1819 mlTrainingConfig = json .dumps (configMap )
19- # TODO query avaiable node labels
20+ # TODO query available node labels
2021 node_labels = ["Paper" ] if not node_labels else node_labels
2122 self ._query_runner .run_query (f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { node_properties } }})" )
23+
24+
25+ def predict (self , graph_name : str , model_name : str , feature_properties : List [str ], target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
26+ configMap = {
27+ "featureProperties" : feature_properties ,
28+ "job_type" : "predict" ,
29+ }
30+ if target_node_label :
31+ configMap ["targetNodeLabel" ] = target_node_label
32+ mlTrainingConfig = json .dumps (configMap )
33+ # TODO query available node labels
34+ node_labels = ["Paper" ] if not node_labels else node_labels
35+ self ._query_runner .run_query (
36+ f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})" )
0 commit comments