Skip to content

Commit 228fc7b

Browse files
committed
Move graphsage endpoints under shared surface
1 parent 701c5ec commit 228fc7b

File tree

10 files changed

+293
-82
lines changed

10 files changed

+293
-82
lines changed

graphdatascience/procedure_surface/api/node_embedding/graphsage_endpoints.py

Lines changed: 69 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,17 @@
44

55
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
66
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
7-
from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2
87
from graphdatascience.procedure_surface.api.node_embedding.graphsage_predict_endpoints import (
98
GraphSageMutateResult,
109
GraphSagePredictEndpoints,
1110
GraphSageWriteResult,
1211
)
1312
from graphdatascience.procedure_surface.api.node_embedding.graphsage_train_endpoints import (
1413
GraphSageTrainEndpoints,
15-
GraphSageTrainResult,
1614
)
1715

1816

19-
class GraphSageEndpoints(GraphSageTrainEndpoints, GraphSagePredictEndpoints):
17+
class GraphSageEndpoints(GraphSagePredictEndpoints):
2018
"""
2119
API for the GraphSage algorithm, combining both training and prediction functionalities.
2220
"""
@@ -29,66 +27,74 @@ def __init__(
2927
self._train_endpoints = train_endpoints
3028
self._predict_endpoints = predict_endpoints
3129

32-
def train(
33-
self,
34-
G: GraphV2,
35-
model_name: str,
36-
feature_properties: list[str],
37-
*,
38-
activation_function: Any | None = None,
39-
negative_sample_weight: int | None = None,
40-
embedding_dimension: int | None = None,
41-
tolerance: float | None = None,
42-
learning_rate: float | None = None,
43-
max_iterations: int | None = None,
44-
sample_sizes: list[int] | None = None,
45-
aggregator: Any | None = None,
46-
penalty_l2: float | None = None,
47-
search_depth: int | None = None,
48-
epochs: int | None = None,
49-
projected_feature_dimension: int | None = None,
50-
batch_sampling_ratio: float | None = None,
51-
store_model_to_disk: bool | None = None,
52-
relationship_types: list[str] | None = None,
53-
node_labels: list[str] | None = None,
54-
username: str | None = None,
55-
log_progress: bool = True,
56-
sudo: bool | None = None,
57-
concurrency: Any | None = None,
58-
job_id: Any | None = None,
59-
batch_size: int | None = None,
60-
relationship_weight_property: str | None = None,
61-
random_seed: Any | None = None,
62-
) -> tuple[GraphSageModelV2, GraphSageTrainResult]:
63-
return self._train_endpoints.train(
64-
G,
65-
model_name,
66-
feature_properties,
67-
activation_function=activation_function,
68-
negative_sample_weight=negative_sample_weight,
69-
embedding_dimension=embedding_dimension,
70-
tolerance=tolerance,
71-
learning_rate=learning_rate,
72-
max_iterations=max_iterations,
73-
sample_sizes=sample_sizes,
74-
aggregator=aggregator,
75-
penalty_l2=penalty_l2,
76-
search_depth=search_depth,
77-
epochs=epochs,
78-
projected_feature_dimension=projected_feature_dimension,
79-
batch_sampling_ratio=batch_sampling_ratio,
80-
store_model_to_disk=store_model_to_disk,
81-
relationship_types=relationship_types,
82-
node_labels=node_labels,
83-
username=username,
84-
log_progress=log_progress,
85-
sudo=sudo,
86-
concurrency=concurrency,
87-
job_id=job_id,
88-
batch_size=batch_size,
89-
relationship_weight_property=relationship_weight_property,
90-
random_seed=random_seed,
91-
)
30+
@property
31+
def train(self) -> GraphSageTrainEndpoints:
32+
"""
33+
Trains a GraphSage model on the given graph.
34+
35+
Parameters
36+
----------
37+
G : GraphV2
38+
The graph to run the algorithm on
39+
model_name : str
40+
Name under which the model will be stored
41+
feature_properties : list[str]
42+
The names of the node properties to use as input features
43+
activation_function : Any | None, default=None
44+
The activation function to apply after each layer
45+
negative_sample_weight : int | None, default=None
46+
Weight of negative samples in the loss function
47+
embedding_dimension : int | None, default=None
48+
The dimension of the generated embeddings
49+
tolerance : float | None, default=None
50+
Tolerance for early stopping based on loss improvement
51+
learning_rate : float | None, default=None
52+
Learning rate for the training optimization
53+
max_iterations : int | None, default=None
54+
Maximum number of training iterations
55+
sample_sizes : list[int] | None, default=None
56+
Number of neighbors to sample at each layer
57+
aggregator : Any | None, default=None
58+
The aggregator function for neighborhood aggregation
59+
penalty_l2 : float | None, default=None
60+
L2 regularization penalty
61+
search_depth : int | None, default=None
62+
Maximum search depth for neighbor sampling
63+
epochs : int | None, default=None
64+
Number of training epochs
65+
projected_feature_dimension : int | None, default=None
66+
Dimension to project input features to before training
67+
batch_sampling_ratio : float | None, default=None
68+
Ratio of nodes to sample for each training batch
69+
store_model_to_disk : bool | None, default=None
70+
Whether to persist the model to disk
71+
relationship_types : list[str] | None, default=None
72+
The relationship types used to select relationships for this algorithm run
73+
node_labels : list[str] | None, default=None
74+
The node labels used to select nodes for this algorithm run
75+
username : str | None = None
76+
The username to attribute the procedure run to
77+
log_progress : bool | None, default=None
78+
Whether to log progress
79+
sudo : bool | None, default=None
80+
Override memory estimation limits
81+
concurrency : Any | None, default=None
82+
The number of concurrent threads
83+
job_id : Any | None, default=None
84+
An identifier for the job
85+
batch_size : int | None, default=None
86+
Batch size for training
87+
relationship_weight_property : str | None, default=None
88+
The property name that contains weight
89+
random_seed : Any | None, default=None
90+
Random seed for reproducible results
91+
92+
Returns
93+
-------
94+
GraphSageModelV2
95+
Trained model
96+
"""
97+
return self._train_endpoints
9298

9399
def stream(
94100
self,

graphdatascience/procedure_surface/api/node_embedding/graphsage_train_endpoints.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,46 @@
55

66
from graphdatascience.procedure_surface.api.base_result import BaseResult
77
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
8+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
89
from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2
910

1011

1112
class GraphSageTrainEndpoints(ABC):
12-
"""
13-
Abstract base class defining the API for the GraphSage algorithm.
14-
"""
13+
@abstractmethod
14+
def __call__(
15+
self,
16+
G: GraphV2,
17+
model_name: str,
18+
feature_properties: list[str],
19+
*,
20+
activation_function: Any | None = None,
21+
negative_sample_weight: int | None = None,
22+
embedding_dimension: int | None = None,
23+
tolerance: float | None = None,
24+
learning_rate: float | None = None,
25+
max_iterations: int | None = None,
26+
sample_sizes: list[int] | None = None,
27+
aggregator: Any | None = None,
28+
penalty_l2: float | None = None,
29+
search_depth: int | None = None,
30+
epochs: int | None = None,
31+
projected_feature_dimension: int | None = None,
32+
batch_sampling_ratio: float | None = None,
33+
store_model_to_disk: bool | None = None,
34+
relationship_types: list[str] | None = None,
35+
node_labels: list[str] | None = None,
36+
username: str | None = None,
37+
log_progress: bool = True,
38+
sudo: bool | None = None,
39+
concurrency: Any | None = None,
40+
job_id: Any | None = None,
41+
batch_size: int | None = None,
42+
relationship_weight_property: str | None = None,
43+
random_seed: Any | None = None,
44+
) -> tuple[GraphSageModelV2, GraphSageTrainResult]: ...
1545

1646
@abstractmethod
17-
def train(
47+
def estimate(
1848
self,
1949
G: GraphV2,
2050
model_name: str,
@@ -44,9 +74,13 @@ def train(
4474
batch_size: int | None = None,
4575
relationship_weight_property: str | None = None,
4676
random_seed: Any | None = None,
47-
) -> tuple[GraphSageModelV2, GraphSageTrainResult]:
77+
) -> EstimationResult:
4878
"""
49-
Trains a GraphSage model on the given graph.
79+
Estimates memory requirements and other statistics for training a GraphSage model.
80+
81+
This method provides memory estimation for the GraphSage training algorithm without
82+
actually executing the training. It helps determine the computational requirements
83+
before running the actual training procedure.
5084
5185
Parameters
5286
----------
@@ -84,9 +118,9 @@ def train(
84118
Ratio of nodes to sample for each training batch
85119
store_model_to_disk : bool | None, default=None
86120
Whether to persist the model to disk
87-
relationship_types : list[str] | None, default=None
121+
relationship_types : list[str] | None = None
88122
The relationship types used to select relationships for this algorithm run
89-
node_labels : list[str] | None, default=None
123+
node_labels : list[str] | None = None
90124
The node labels used to select nodes for this algorithm run
91125
username : str | None = None
92126
The username to attribute the procedure run to
@@ -107,8 +141,8 @@ def train(
107141
108142
Returns
109143
-------
110-
GraphSageModelV2
111-
Trained model
144+
EstimationResult
145+
The estimation result containing memory requirements and other statistics
112146
"""
113147

114148

graphdatascience/procedure_surface/arrow/node_embedding/graphsage_train_arrow_endpoints.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
44
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
55
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
6+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
67
from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2
78
from graphdatascience.procedure_surface.api.node_embedding.graphsage_train_endpoints import (
89
GraphSageTrainEndpoints,
@@ -29,7 +30,7 @@ def __init__(
2930
)
3031
self._model_api = ModelApiArrow(arrow_client)
3132

32-
def train(
33+
def __call__(
3334
self,
3435
G: GraphV2,
3536
model_name: str,
@@ -100,3 +101,67 @@ def train(
100101
train_result = GraphSageTrainResult(**result)
101102

102103
return model, train_result
104+
105+
def estimate(
106+
self,
107+
G: GraphV2,
108+
model_name: str,
109+
feature_properties: list[str],
110+
*,
111+
activation_function: Any | None = None,
112+
negative_sample_weight: int | None = None,
113+
embedding_dimension: int | None = None,
114+
tolerance: float | None = None,
115+
learning_rate: float | None = None,
116+
max_iterations: int | None = None,
117+
sample_sizes: list[int] | None = None,
118+
aggregator: Any | None = None,
119+
penalty_l2: float | None = None,
120+
search_depth: int | None = None,
121+
epochs: int | None = None,
122+
projected_feature_dimension: int | None = None,
123+
batch_sampling_ratio: float | None = None,
124+
store_model_to_disk: bool | None = None,
125+
relationship_types: list[str] | None = None,
126+
node_labels: list[str] | None = None,
127+
username: str | None = None,
128+
log_progress: bool = True,
129+
sudo: bool | None = None,
130+
concurrency: Any | None = None,
131+
job_id: Any | None = None,
132+
batch_size: int | None = None,
133+
relationship_weight_property: str | None = None,
134+
random_seed: Any | None = None,
135+
) -> EstimationResult:
136+
return self._node_property_endpoints.estimate(
137+
estimate_endpoint="v2/embeddings.graphSage.train.estimate",
138+
G=G,
139+
algo_config=self._node_property_endpoints.create_estimate_config(
140+
model_name=model_name,
141+
feature_properties=feature_properties,
142+
activation_function=activation_function,
143+
negative_sample_weight=negative_sample_weight,
144+
embedding_dimension=embedding_dimension,
145+
tolerance=tolerance,
146+
learning_rate=learning_rate,
147+
max_iterations=max_iterations,
148+
sample_sizes=sample_sizes,
149+
aggregator=aggregator,
150+
penalty_l2=penalty_l2,
151+
search_depth=search_depth,
152+
epochs=epochs,
153+
projected_feature_dimension=projected_feature_dimension,
154+
batch_sampling_ratio=batch_sampling_ratio,
155+
store_model_to_disk=store_model_to_disk,
156+
relationship_types=relationship_types,
157+
node_labels=node_labels,
158+
username=username,
159+
log_progress=log_progress,
160+
sudo=sudo,
161+
concurrency=concurrency,
162+
job_id=job_id,
163+
batch_size=batch_size,
164+
relationship_weight_property=relationship_weight_property,
165+
random_seed=random_seed,
166+
),
167+
)

0 commit comments

Comments
 (0)