66
77
88class GNNNodeClassificationRunner (UncallableNamespace , IllegalAttrChecker ):
9+ def make_graph_sage_config (self , graph_sage_config ):
10+ GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config" : {}, "num_neighbors" : [25 , 10 ], "dropout" : 0.5 ,
11+ "hidden_channels" : 256 }
12+ final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG
13+ if graph_sage_config :
14+ bad_keys = []
15+ for key in graph_sage_config :
16+ if key not in GRAPH_SAGE_DEFAULT_CONFIG :
17+ bad_keys .append (key )
18+ if len (bad_keys ) > 0 :
19+ raise Exception (f"Argument graph_sage_config contains invalid keys { ', ' .join (bad_keys )} ." )
20+
21+ final_sage_config .update (graph_sage_config )
22+ return final_sage_config
23+
924 def train (
1025 self ,
1126 graph_name : str ,
@@ -15,13 +30,15 @@ def train(
1530 relationship_types : List [str ],
1631 target_node_label : str = None ,
1732 node_labels : List [str ] = None ,
33+ graph_sage_config = None
1834 ) -> "Series[Any]" : # noqa: F821
1935 mlConfigMap = {
2036 "featureProperties" : feature_properties ,
2137 "targetProperty" : target_property ,
2238 "job_type" : "train" ,
2339 "nodeProperties" : feature_properties + [target_property ],
24- "relationshipTypes" : relationship_types
40+ "relationshipTypes" : relationship_types ,
41+ "graph_sage_config" : self .make_graph_sage_config (graph_sage_config )
2542 }
2643
2744 if target_node_label :
0 commit comments