Skip to content

Commit 701c5ec

Browse files
committed
Fix coverage test
* move graphsage endpoints behind common interface * add hdbscan as missing
1 parent c23873c commit 701c5ec

File tree

4 files changed

+238
-39
lines changed

4 files changed

+238
-39
lines changed
graphdatascience/procedure_surface/api/node_embedding/graphsage_endpoints.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
from typing import Any
2+
3+
from pandas import DataFrame
4+
5+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
6+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
7+
from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2
8+
from graphdatascience.procedure_surface.api.node_embedding.graphsage_predict_endpoints import (
9+
GraphSageMutateResult,
10+
GraphSagePredictEndpoints,
11+
GraphSageWriteResult,
12+
)
13+
from graphdatascience.procedure_surface.api.node_embedding.graphsage_train_endpoints import (
14+
GraphSageTrainEndpoints,
15+
GraphSageTrainResult,
16+
)
17+
18+
19+
class GraphSageEndpoints(GraphSageTrainEndpoints, GraphSagePredictEndpoints):
    """
    API for the GraphSage algorithm, combining both training and prediction functionalities.

    The class is a thin facade: it holds one train-endpoints object and one
    predict-endpoints object and forwards every call, with all arguments
    unchanged, to the matching method on the wrapped object.

    * ``train`` is forwarded to the train endpoints.
    * ``stream``, ``write``, ``mutate`` and ``estimate`` are forwarded to the
      predict endpoints.
    """

    def __init__(
        self,
        train_endpoints: GraphSageTrainEndpoints,
        predict_endpoints: GraphSagePredictEndpoints,
    ) -> None:
        """
        Store the two delegates.

        Parameters
        ----------
        train_endpoints
            Object that receives every ``train`` call.
        predict_endpoints
            Object that receives every ``stream``, ``write``, ``mutate`` and
            ``estimate`` call.
        """
        self._train_endpoints = train_endpoints
        self._predict_endpoints = predict_endpoints

    def train(
        self,
        G: GraphV2,
        model_name: str,
        feature_properties: list[str],
        *,
        activation_function: Any | None = None,
        negative_sample_weight: int | None = None,
        embedding_dimension: int | None = None,
        tolerance: float | None = None,
        learning_rate: float | None = None,
        max_iterations: int | None = None,
        sample_sizes: list[int] | None = None,
        aggregator: Any | None = None,
        penalty_l2: float | None = None,
        search_depth: int | None = None,
        epochs: int | None = None,
        projected_feature_dimension: int | None = None,
        batch_sampling_ratio: float | None = None,
        store_model_to_disk: bool | None = None,
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        username: str | None = None,
        log_progress: bool = True,
        sudo: bool | None = None,
        concurrency: Any | None = None,
        job_id: Any | None = None,
        batch_size: int | None = None,
        relationship_weight_property: str | None = None,
        random_seed: Any | None = None,
    ) -> tuple[GraphSageModelV2, GraphSageTrainResult]:
        """
        Forward a training request to the wrapped train endpoints.

        Every argument is passed through unchanged; the options after ``*``
        are keyword-only. The wrapped object's return value is returned as is.
        """
        return self._train_endpoints.train(
            G,
            model_name,
            feature_properties,
            activation_function=activation_function,
            negative_sample_weight=negative_sample_weight,
            embedding_dimension=embedding_dimension,
            tolerance=tolerance,
            learning_rate=learning_rate,
            max_iterations=max_iterations,
            sample_sizes=sample_sizes,
            aggregator=aggregator,
            penalty_l2=penalty_l2,
            search_depth=search_depth,
            epochs=epochs,
            projected_feature_dimension=projected_feature_dimension,
            batch_sampling_ratio=batch_sampling_ratio,
            store_model_to_disk=store_model_to_disk,
            relationship_types=relationship_types,
            node_labels=node_labels,
            username=username,
            log_progress=log_progress,
            sudo=sudo,
            concurrency=concurrency,
            job_id=job_id,
            batch_size=batch_size,
            relationship_weight_property=relationship_weight_property,
            random_seed=random_seed,
        )

    def stream(
        self,
        G: GraphV2,
        model_name: str,
        *,
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        username: str | None = None,
        log_progress: bool = True,
        sudo: bool | None = None,
        concurrency: Any | None = None,
        job_id: Any | None = None,
        batch_size: int | None = None,
    ) -> DataFrame:
        """
        Forward a stream request to the wrapped predict endpoints.

        Every argument is passed through unchanged and the wrapped object's
        return value is returned as is.
        """
        return self._predict_endpoints.stream(
            G,
            model_name,
            relationship_types=relationship_types,
            node_labels=node_labels,
            username=username,
            log_progress=log_progress,
            sudo=sudo,
            concurrency=concurrency,
            job_id=job_id,
            batch_size=batch_size,
        )

    def write(
        self,
        G: GraphV2,
        model_name: str,
        write_property: str,
        *,
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        username: str | None = None,
        log_progress: bool = True,
        sudo: bool | None = None,
        concurrency: Any | None = None,
        write_concurrency: int | None = None,
        job_id: Any | None = None,
        batch_size: int | None = None,
    ) -> GraphSageWriteResult:
        """
        Forward a write request to the wrapped predict endpoints.

        Every argument is passed through unchanged and the wrapped object's
        return value is returned as is.
        """
        return self._predict_endpoints.write(
            G,
            model_name,
            write_property,
            relationship_types=relationship_types,
            node_labels=node_labels,
            username=username,
            log_progress=log_progress,
            sudo=sudo,
            concurrency=concurrency,
            write_concurrency=write_concurrency,
            job_id=job_id,
            batch_size=batch_size,
        )

    # NOTE(review): unlike train/stream/write, the optional parameters of
    # mutate and estimate are not keyword-only (no `*`). Presumably this
    # mirrors the signatures on GraphSagePredictEndpoints - confirm before
    # tightening, since adding `*` would break positional callers.
    def mutate(
        self,
        G: GraphV2,
        model_name: str,
        mutate_property: str,
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        username: str | None = None,
        log_progress: bool = True,
        sudo: bool | None = None,
        concurrency: Any | None = None,
        job_id: Any | None = None,
        batch_size: int | None = None,
    ) -> GraphSageMutateResult:
        """
        Forward a mutate request to the wrapped predict endpoints.

        Optional arguments are handed over by keyword, so the positional
        order used by the caller does not leak into the delegate call.
        """
        return self._predict_endpoints.mutate(
            G,
            model_name,
            mutate_property,
            relationship_types=relationship_types,
            node_labels=node_labels,
            username=username,
            log_progress=log_progress,
            sudo=sudo,
            concurrency=concurrency,
            job_id=job_id,
            batch_size=batch_size,
        )

    def estimate(
        self,
        G: GraphV2 | dict[str, Any],
        model_name: str,
        relationship_types: list[str] | None = None,
        node_labels: list[str] | None = None,
        batch_size: int | None = None,
        concurrency: int | None = None,
        log_progress: bool = True,
        username: str | None = None,
        sudo: bool | None = None,
        job_id: str | None = None,
    ) -> EstimationResult:
        """
        Forward an estimation request to the wrapped predict endpoints.

        Only the predict-side estimate is exposed here; this class does not
        forward any estimate call to the train endpoints.
        """
        return self._predict_endpoints.estimate(
            G,
            model_name,
            relationship_types=relationship_types,
            node_labels=node_labels,
            batch_size=batch_size,
            concurrency=concurrency,
            log_progress=log_progress,
            username=username,
            sudo=sudo,
            job_id=job_id,
        )

graphdatascience/procedure_surface/api/node_embedding/graphsage_predict_endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def write(
7979
job_id: Any | None = None,
8080
batch_size: int | None = None,
8181
) -> GraphSageWriteResult:
82-
""" "
82+
"""
8383
Uses a pre-trained GraphSage model to predict embeddings for a graph and writes the results back to the database.
8484
8585
Parameters
@@ -130,7 +130,7 @@ def mutate(
130130
job_id: Any | None = None,
131131
batch_size: int | None = None,
132132
) -> GraphSageMutateResult:
133-
""" "
133+
"""
134134
Uses a pre-trained GraphSage model to predict embeddings for a graph and writes the results back to the graph as a node property.
135135
136136
Parameters

graphdatascience/session/session_v2_endpoints.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from graphdatascience.procedure_surface.api.community.sllpa_endpoints import SllpaEndpoints
1515
from graphdatascience.procedure_surface.api.community.triangle_count_endpoints import TriangleCountEndpoints
16+
from graphdatascience.procedure_surface.api.node_embedding.graphsage_endpoints import GraphSageEndpoints
1617
from graphdatascience.procedure_surface.arrow.catalog_arrow_endpoints import CatalogArrowEndpoints
1718
from graphdatascience.procedure_surface.arrow.centrality.articlerank_arrow_endpoints import ArticleRankArrowEndpoints
1819
from graphdatascience.procedure_surface.arrow.centrality.articulationpoints_arrow_endpoints import (
@@ -122,15 +123,14 @@ def fast_rp(self) -> FastRPArrowEndpoints:
122123
return FastRPArrowEndpoints(self._arrow_client, self._write_back_client, show_progress=self._show_progress)
123124

124125
@property
125-
def graphsage_predict(self) -> GraphSagePredictArrowEndpoints:
126-
return GraphSagePredictArrowEndpoints(
127-
self._arrow_client, self._write_back_client, show_progress=self._show_progress
128-
)
129-
130-
@property
131-
def graphsage_train(self) -> GraphSageTrainArrowEndpoints:
132-
return GraphSageTrainArrowEndpoints(
133-
self._arrow_client, self._write_back_client, show_progress=self._show_progress
126+
def graph_sage(self) -> GraphSageEndpoints:
127+
return GraphSageEndpoints(
128+
train_endpoints=GraphSageTrainArrowEndpoints(
129+
self._arrow_client, self._write_back_client, show_progress=self._show_progress
130+
),
131+
predict_endpoints=GraphSagePredictArrowEndpoints(
132+
self._arrow_client, self._write_back_client, show_progress=self._show_progress
133+
),
134134
)
135135

136136
@property
@@ -165,6 +165,7 @@ def label_propagation(self) -> LabelPropagationEndpoints:
165165
self._arrow_client, self._write_back_client, show_progress=self._show_progress
166166
)
167167

168+
@property
168169
def leiden(self) -> LeidenEndpoints:
169170
return LeidenArrowEndpoints(self._arrow_client, self._write_back_client, show_progress=self._show_progress)
170171

graphdatascience/tests/integrationV2/procedure_surface/session/test_session_endpoint_coverage.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1+
import re
12
from collections import defaultdict
23

34
import pytest
4-
from pydantic.alias_generators import to_snake
55

66
from graphdatascience import QueryRunner, ServerVersion
77
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
88
from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
99
from graphdatascience.session.session_v2_endpoints import SessionV2Endpoints
1010

1111
MISSING_ALGO_ENDPOINTS = {
12-
"embeddings.graphSage.train.estimate", # TODO fix this by moving behind shared interface
13-
"embeddings.graphSage.estimate",
12+
"embeddings.graphSage.train.estimate",
13+
"community.hdbscan",
14+
"community.hdbscan.estimate",
1415
"similarity.knn.filtered",
1516
"similarity.knn.filtered.estimate",
1617
"similarity.nodeSimilarity.filtered",
@@ -44,37 +45,21 @@
4445
# centrality algos
4546
"betweenness": "betweenness_centrality",
4647
"celf": "influence_maximization_celf",
47-
"celf.estimate": "influence_maximization_celf.estimate",
4848
"closeness": "closeness_centrality",
49-
"closeness.estimate": "closeness_centrality.estimate",
5049
"degree": "degree_centrality",
51-
"degree.estimate": "degree_centrality.estimate",
5250
"eigenvector": "eigenvector_centrality",
53-
"eigenvector.estimate": "eigenvector_centrality.estimate",
5451
"harmonic": "harmonic_centrality",
55-
"harmonic.estimate": "harmonic_centrality.estimate",
5652
"localClusteringCoefficient": "local_clustering_coefficient",
57-
"localClusteringCoefficient.estimate": "local_clustering_coefficient.estimate",
5853
# community algos
54+
"cliquecounting": "clique_counting",
5955
"k1coloring": "k1_coloring",
60-
"k1coloring.estimate": "k1_coloring.estimate",
6156
"kcore": "k_core_decomposition",
62-
"kcore.estimate": "k_core_decomposition.estimate",
6357
"maxkcut": "max_k_cut",
64-
"maxkcut.estimate": "max_k_cut.estimate",
6558
"modularityOptimization": "modularity_optimization",
66-
"modularityOptimization.estimate": "modularity_optimization.estimate",
67-
"sllpa": "sllpa",
68-
"sllpa.estimate": "sllpa.estimate",
69-
"triangleCount": "triangle_count",
70-
"triangleCount.estimate": "triangle_count.estimate",
7159
# embedding algos
7260
"fastrp": "fast_rp",
73-
"fastrp.estimate": "fast_rp.estimate",
74-
"graphSage": "graphsage_predict",
75-
"graphSage.train": "graphsage_train",
61+
"graphSage": "graphsage",
7662
"hashgnn": "hash_gnn",
77-
"hashgnn.estimate": "hash_gnn.estimate",
7863
}
7964

8065

@@ -88,13 +73,24 @@ def gds(arrow_client: AuthenticatedArrowClient, db_query_runner: QueryRunner) ->
8873
)
8974

9075

76+
def to_snake(camel: str) -> str:
    """
    Convert a camelCase, PascalCase or kebab-case name to snake_case.

    Adjusted version of ``pydantic.alias_generators.to_snake`` without the
    digit handling, so digits never trigger an underscore
    (``"k1coloring"`` stays ``"k1coloring"``).

    Parameters
    ----------
    camel
        The name to convert.

    Returns
    -------
    str
        The lower-cased name with underscores at word boundaries.

    Examples
    --------
    >>> to_snake("graphSage")
    'graph_sage'
    >>> to_snake("HTTPResponse")
    'http_response'
    """
    # Split an acronym from the capitalised word that follows it
    # ("HTTPResponse" -> "HTTP_Response").
    snake = re.sub(r"([A-Z]+)([A-Z][a-z])", r"\1_\2", camel)
    # Insert an underscore between a lowercase letter and an uppercase letter.
    snake = re.sub(r"([a-z])([A-Z])", r"\1_\2", snake)
    # Replace hyphens with underscores to handle kebab-case.
    return snake.replace("-", "_").lower()
86+
87+
9188
def check_gds_v2_availability(endpoints: SessionV2Endpoints, algo: str) -> bool:
9289
"""Check if an algorithm is available through gds.v2 interface"""
9390

94-
algo = ENDPOINT_MAPPINGS.get(algo, algo)
95-
9691
algo_parts = algo.split(".")
9792
algo_parts = [to_snake(part) for part in algo_parts]
93+
algo_parts = [ENDPOINT_MAPPINGS.get(part, part) for part in algo_parts]
9894

9995
callable_object = endpoints
10096
for algo_part in algo_parts:
@@ -110,7 +106,6 @@ def check_gds_v2_availability(endpoints: SessionV2Endpoints, algo: str) -> bool:
110106

111107
@pytest.mark.db_integration
112108
def test_algo_coverage(gds: AuraGraphDataScience) -> None:
113-
"""Test that all available Arrow actions are accessible through gds.v2"""
114109
arrow_client = gds.v2._arrow_client
115110

116111
# Get all available Arrow actions
@@ -151,9 +146,9 @@ def test_algo_coverage(gds: AuraGraphDataScience) -> None:
151146
print(f"Available through gds.v2: {len(available_endpoints)}")
152147

153148
# check if any previously missing algos are now available
154-
assert not available_endpoints.intersection(MISSING_ALGO_ENDPOINTS), (
155-
"Endpoints now available, please remove from MISSING_ALGO_ENDPOINTS"
156-
)
149+
newly_available_endpoints = available_endpoints.intersection(MISSING_ALGO_ENDPOINTS)
150+
assert not newly_available_endpoints, "Endpoints now available, please remove from MISSING_ALGO_ENDPOINTS"
157151

158152
# check missing endpoints against known missing algos
159-
assert missing_endpoints.difference(MISSING_ALGO_ENDPOINTS), "Unexpectedly missing endpoints"
153+
missing_endpoints = missing_endpoints.difference(MISSING_ALGO_ENDPOINTS)
154+
assert not missing_endpoints, f"Unexpectedly missing endpoints {len(missing_endpoints)}"

0 commit comments

Comments
 (0)