|
6 | 6 |
|
7 | 7 |
|
8 | 8 | class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker): |
| 9 | + def make_graph_sage_config(self, graph_sage_config): |
| 10 | + GRAPH_SAGE_DEFAULT_CONFIG = {"layer_config": {}, "num_neighbors": [25, 10], "dropout": 0.5, |
| 11 | + "hidden_channels": 256} |
| 12 | + final_sage_config = GRAPH_SAGE_DEFAULT_CONFIG |
| 13 | + if graph_sage_config: |
| 14 | + bad_keys = [] |
| 15 | + for key in graph_sage_config: |
| 16 | + if key not in GRAPH_SAGE_DEFAULT_CONFIG: |
| 17 | + bad_keys.append(key) |
| 18 | + if len(bad_keys) > 0: |
| 19 | + raise Exception(f"Argument graph_sage_config contains invalid keys {', '.join(bad_keys)}.") |
| 20 | + |
| 21 | + final_sage_config.update(graph_sage_config) |
| 22 | + return final_sage_config |
| 23 | + |
9 | 24 | def train( |
10 | | - self, |
11 | | - graph_name: str, |
12 | | - model_name: str, |
13 | | - feature_properties: List[str], |
14 | | - target_property: str, |
15 | | - relationship_types: List[str], |
16 | | - target_node_label: str = None, |
17 | | - node_labels: List[str] = None, |
| 25 | + self, |
| 26 | + graph_name: str, |
| 27 | + model_name: str, |
| 28 | + feature_properties: List[str], |
| 29 | + target_property: str, |
| 30 | + relationship_types: List[str], |
| 31 | + target_node_label: str = None, |
| 32 | + node_labels: List[str] = None, |
| 33 | + graph_sage_config = None |
18 | 34 | ) -> "Series[Any]": # noqa: F821 |
| 35 | + |
19 | 36 | mlConfigMap = { |
20 | 37 | "featureProperties": feature_properties, |
21 | 38 | "targetProperty": target_property, |
22 | 39 | "job_type": "train", |
23 | 40 | "nodeProperties": feature_properties + [target_property], |
24 | | - "relationshipTypes": relationship_types |
| 41 | + "relationshipTypes": relationship_types, |
| 42 | + "graph_sage_config": self.make_graph_sage_config(graph_sage_config) |
25 | 43 | } |
26 | 44 |
|
27 | 45 | if target_node_label: |
|
0 commit comments