44
55from graphdatascience .procedure_surface .api .catalog .graph_api import GraphV2
66from graphdatascience .procedure_surface .api .estimation_result import EstimationResult
7- from graphdatascience .procedure_surface .api .model .graphsage_model import GraphSageModelV2
87from graphdatascience .procedure_surface .api .node_embedding .graphsage_predict_endpoints import (
98 GraphSageMutateResult ,
109 GraphSagePredictEndpoints ,
1110 GraphSageWriteResult ,
1211)
1312from graphdatascience .procedure_surface .api .node_embedding .graphsage_train_endpoints import (
1413 GraphSageTrainEndpoints ,
15- GraphSageTrainResult ,
1614)
1715
1816
19- class GraphSageEndpoints (GraphSageTrainEndpoints , GraphSagePredictEndpoints ):
17+ class GraphSageEndpoints (GraphSagePredictEndpoints ):
2018 """
2119 API for the GraphSage algorithm, combining both training and prediction functionalities.
2220 """
@@ -29,66 +27,74 @@ def __init__(
2927 self ._train_endpoints = train_endpoints
3028 self ._predict_endpoints = predict_endpoints
3129
32- def train (
33- self ,
34- G : GraphV2 ,
35- model_name : str ,
36- feature_properties : list [str ],
37- * ,
38- activation_function : Any | None = None ,
39- negative_sample_weight : int | None = None ,
40- embedding_dimension : int | None = None ,
41- tolerance : float | None = None ,
42- learning_rate : float | None = None ,
43- max_iterations : int | None = None ,
44- sample_sizes : list [int ] | None = None ,
45- aggregator : Any | None = None ,
46- penalty_l2 : float | None = None ,
47- search_depth : int | None = None ,
48- epochs : int | None = None ,
49- projected_feature_dimension : int | None = None ,
50- batch_sampling_ratio : float | None = None ,
51- store_model_to_disk : bool | None = None ,
52- relationship_types : list [str ] | None = None ,
53- node_labels : list [str ] | None = None ,
54- username : str | None = None ,
55- log_progress : bool = True ,
56- sudo : bool | None = None ,
57- concurrency : Any | None = None ,
58- job_id : Any | None = None ,
59- batch_size : int | None = None ,
60- relationship_weight_property : str | None = None ,
61- random_seed : Any | None = None ,
62- ) -> tuple [GraphSageModelV2 , GraphSageTrainResult ]:
63- return self ._train_endpoints .train (
64- G ,
65- model_name ,
66- feature_properties ,
67- activation_function = activation_function ,
68- negative_sample_weight = negative_sample_weight ,
69- embedding_dimension = embedding_dimension ,
70- tolerance = tolerance ,
71- learning_rate = learning_rate ,
72- max_iterations = max_iterations ,
73- sample_sizes = sample_sizes ,
74- aggregator = aggregator ,
75- penalty_l2 = penalty_l2 ,
76- search_depth = search_depth ,
77- epochs = epochs ,
78- projected_feature_dimension = projected_feature_dimension ,
79- batch_sampling_ratio = batch_sampling_ratio ,
80- store_model_to_disk = store_model_to_disk ,
81- relationship_types = relationship_types ,
82- node_labels = node_labels ,
83- username = username ,
84- log_progress = log_progress ,
85- sudo = sudo ,
86- concurrency = concurrency ,
87- job_id = job_id ,
88- batch_size = batch_size ,
89- relationship_weight_property = relationship_weight_property ,
90- random_seed = random_seed ,
91- )
30+ @property
31+ def train (self ) -> GraphSageTrainEndpoints :
32+ """
33+ Trains a GraphSage model on the given graph.
34+
35+ Parameters
36+ ----------
37+ G : GraphV2
38+ The graph to run the algorithm on
39+ model_name : str
40+ Name under which the model will be stored
41+ feature_properties : list[str]
42+ The names of the node properties to use as input features
43+ activation_function : Any | None, default=None
44+ The activation function to apply after each layer
45+ negative_sample_weight : int | None, default=None
46+ Weight of negative samples in the loss function
47+ embedding_dimension : int | None, default=None
48+ The dimension of the generated embeddings
49+ tolerance : float | None, default=None
50+ Tolerance for early stopping based on loss improvement
51+ learning_rate : float | None, default=None
52+ Learning rate for the training optimization
53+ max_iterations : int | None, default=None
54+ Maximum number of training iterations
55+ sample_sizes : list[int] | None, default=None
56+ Number of neighbors to sample at each layer
57+ aggregator : Any | None, default=None
58+ The aggregator function for neighborhood aggregation
59+ penalty_l2 : float | None, default=None
60+ L2 regularization penalty
61+ search_depth : int | None, default=None
62+ Maximum search depth for neighbor sampling
63+ epochs : int | None, default=None
64+ Number of training epochs
65+ projected_feature_dimension : int | None, default=None
66+ Dimension to project input features to before training
67+ batch_sampling_ratio : float | None, default=None
68+ Ratio of nodes to sample for each training batch
69+ store_model_to_disk : bool | None, default=None
70+ Whether to persist the model to disk
71+ relationship_types : list[str] | None, default=None
72+ The relationship types used to select relationships for this algorithm run
73+ node_labels : list[str] | None, default=None
74+ The node labels used to select nodes for this algorithm run
75+ username : str | None = None
76+ The username to attribute the procedure run to
77+ log_progress : bool | None, default=None
78+ Whether to log progress
79+ sudo : bool | None, default=None
80+ Override memory estimation limits
81+ concurrency : Any | None, default=None
82+ The number of concurrent threads
83+ job_id : Any | None, default=None
84+ An identifier for the job
85+ batch_size : int | None, default=None
86+ Batch size for training
87+ relationship_weight_property : str | None, default=None
88+ The property name that contains weight
89+ random_seed : Any | None, default=None
90+ Random seed for reproducible results
91+
92+ Returns
93+ -------
94+ GraphSageModelV2
95+ Trained model
96+ """
97+ return self ._train_endpoints
9298
9399 def stream (
94100 self ,
0 commit comments