@@ -12,6 +12,7 @@ def train(
1212 model_name : str ,
1313 feature_properties : List [str ],
1414 target_property : str ,
15+ relationship_types : List [str ],
1516 target_node_label : str = None ,
1617 node_labels : List [str ] = None ,
1718 ) -> "Series[Any]" : # noqa: F821
@@ -20,6 +21,7 @@ def train(
2021 "targetProperty" : target_property ,
2122 "job_type" : "train" ,
2223 "nodeProperties" : feature_properties + [target_property ],
24+ "relationshipTypes" : relationship_types
2325 }
2426
2527 if target_node_label :
@@ -31,10 +33,9 @@ def train(
3133
3234 # token and uri will be injected by arrow_query_runner
3335 self ._query_runner .run_query (
34- "CALL gds.upload.graph($graph_name, $ config)" ,
36+ "CALL gds.upload.graph($config)" ,
3537 params = {
36- "graph_name" : graph_name ,
37- "config" : {"mlTrainingConfig" : mlTrainingConfig , "modelName" : model_name },
38+ "config" : {"mlTrainingConfig" : mlTrainingConfig , "graphName" : graph_name , "modelName" : model_name },
3839 },
3940 )
4041
@@ -43,13 +44,15 @@ def predict(
4344 graph_name : str ,
4445 model_name : str ,
4546 feature_properties : List [str ],
47+ relationship_types : List [str ],
4648 target_node_label : str = None ,
4749 node_labels : List [str ] = None ,
4850 ) -> "Series[Any]" : # noqa: F821
4951 mlConfigMap = {
5052 "featureProperties" : feature_properties ,
5153 "job_type" : "predict" ,
5254 "nodeProperties" : feature_properties ,
55+ "relationshipTypes" : relationship_types
5356 }
5457 if target_node_label :
5558 mlConfigMap ["targetNodeLabel" ] = target_node_label
@@ -58,9 +61,8 @@ def predict(
5861
5962 mlTrainingConfig = json .dumps (mlConfigMap )
6063 self ._query_runner .run_query (
61- "CALL gds.upload.graph($graph_name, $ config)" ,
64+ "CALL gds.upload.graph($config)" ,
6265 params = {
63- "graph_name" : graph_name ,
64- "config" : {"mlTrainingConfig" : mlTrainingConfig , "modelName" : model_name },
66+ "config" : {"mlTrainingConfig" : mlTrainingConfig , "graphName" : graph_name , "modelName" : model_name },
6567 },
6668 ) # type: ignore
0 commit comments