Commit be27c44

Test coverage based on available actions
ref GDSA-144

1 parent bc1bb88

2 files changed: +172 −0 lines changed

graphdatascience/arrow_client/authenticated_flight_client.py

Lines changed: 4 additions & 0 deletions

@@ -9,6 +9,7 @@
 from pyarrow import flight
 from pyarrow._flight import (
     Action,
+    ActionType,
     FlightInternalError,
     FlightStreamReader,
     FlightTimedOutError,
@@ -192,6 +193,9 @@ def run_with_retry() -> Iterator[Result]:
 
         return run_with_retry()
 
+    def list_actions(self) -> set[ActionType]:
+        return self._flight_client.list_actions()  # type: ignore
+
     def _instantiate_flight_client(self) -> flight.FlightClient:
         location = (
             flight.Location.for_grpc_tls(self._host, self._port)
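
For context, a minimal usage sketch of the new wrapper, assuming `client` is an already-constructed AuthenticatedArrowClient; the "v2/" prefix filtering mirrors what the new test below does:

    # Hypothetical snippet, not part of the commit: enumerate the v2 actions
    # the Arrow server advertises, via the new list_actions() wrapper.
    v2_actions = [
        action.type.removeprefix("v2/")
        for action in client.list_actions()
        if action.type.startswith("v2/")
    ]
    print(sorted(v2_actions))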
Lines changed: 168 additions & 0 deletions

@@ -0,0 +1,168 @@
+from collections import defaultdict
+
+import pytest
+from pydantic.alias_generators import to_snake
+
+from graphdatascience import QueryRunner, ServerVersion
+from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
+from graphdatascience.session.aura_graph_data_science import AuraGraphDataScience
+from graphdatascience.session.session_v2_endpoints import SessionV2Endpoints
+
+MISSING_ALGO_ENDPOINTS = {
+    "community.kmeans",
+    "community.cliquecounting",
+    "community.maxkcut",
+    "community.cliquecounting.estimate",
+    "community.labelPropagation.estimate",
+    "community.maxkcut.estimate",
+    "community.k1coloring",
+    "community.triangleCount.estimate",
+    "community.kmeans.estimate",
+    "community.leiden",
+    "community.sllpa.estimate",
+    "community.modularityOptimization",
+    "community.sllpa",
+    "community.localClusteringCoefficient",
+    "community.modularityOptimization.estimate",
+    "community.labelPropagation",
+    "community.localClusteringCoefficient.estimate",
+    "community.leiden.estimate",
+    "community.triangleCount",
+    "embeddings.graphSage.train.estimate",  # TODO fix this by moving behind shared interface
+    "embeddings.graphSage.estimate",
+    "similarity.knn.filtered",
+    "similarity.knn.filtered.estimate",
+    "similarity.nodeSimilarity.filtered",
+    "similarity.nodeSimilarity.filtered.estimate",
+    "similarity.nodeSimilarity",
+    "similarity.knn",
+    "similarity.nodeSimilarity.estimate",
+    "similarity.knn.estimate",
+    "pathfinding.sourceTarget.dijkstra.estimate",
+    "pathfinding.sourceTarget.aStar",
+    "pathfinding.prizeSteinerTree.estimate",
+    "pathfinding.sourceTarget.yens",
+    "pathfinding.singleSource.deltaStepping.estimate",
+    "pathfinding.singleSource.deltaStepping",
+    "pathfinding.steinerTree",
+    "pathfinding.singleSource.dijkstra",
+    "pathfinding.singleSource.bellmanFord",
+    "pathfinding.steinerTree.estimate",
+    "pathfinding.singleSource.bellmanFord.estimate",
+    "pathfinding.singleSource.dijkstra.estimate",
+    "pathfinding.prizeSteinerTree",
+    "pathfinding.spanningTree.estimate",
+    "pathfinding.sourceTarget.dijkstra",
+    "pathfinding.kSpanningTree",
+    "pathfinding.spanningTree",
+    "pathfinding.sourceTarget.aStar.estimate",
+    "pathfinding.sourceTarget.yens.estimate",
+}
+
+ENDPOINT_MAPPINGS = {
+    # centrality algos
+    "betweenness": "betweenness_centrality",
+    "celf": "influence_maximization_celf",
+    "celf.estimate": "influence_maximization_celf.estimate",
+    "closeness": "closeness_centrality",
+    "closeness.estimate": "closeness_centrality.estimate",
+    "degree": "degree_centrality",
+    "degree.estimate": "degree_centrality.estimate",
+    "eigenvector": "eigenvector_centrality",
+    "eigenvector.estimate": "eigenvector_centrality.estimate",
+    "harmonic": "harmonic_centrality",
+    "harmonic.estimate": "harmonic_centrality.estimate",
+    # community algos
+    "k1coloring": "k1_coloring",
+    "k1coloring.estimate": "k1_coloring.estimate",
+    "kcore": "k_core_decomposition",
+    "kcore.estimate": "k_core_decomposition.estimate",
+    # embedding algos
+    "fastrp": "fast_rp",
+    "fastrp.estimate": "fast_rp.estimate",
+    "graphSage": "graphsage_predict",
+    "graphSage.train": "graphsage_train",
+    "hashgnn": "hash_gnn",
+    "hashgnn.estimate": "hash_gnn.estimate",
+}
+
+
+@pytest.fixture
+def gds(arrow_client: AuthenticatedArrowClient, db_query_runner: QueryRunner) -> AuraGraphDataScience:
+    return AuraGraphDataScience(
+        query_runner=db_query_runner,
+        delete_fn=lambda: True,
+        gds_version=ServerVersion.from_string("2.7.0"),
+        v2_endpoints=SessionV2Endpoints(arrow_client, db_query_runner, show_progress=False),
+    )
+
+
+def check_gds_v2_availability(endpoints: SessionV2Endpoints, algo: str) -> bool:
+    """Check if an algorithm is available through the gds.v2 interface"""
+
+    algo = ENDPOINT_MAPPINGS.get(algo, algo)
+
+    algo_parts = algo.split(".")
+    algo_parts = [to_snake(part) for part in algo_parts]
+
+    callable_object = endpoints
+    for algo_part in algo_parts:
+        # Get the algorithm endpoint
+        if not hasattr(callable_object, algo_part):
+            return False
+
+        callable_object = getattr(callable_object, algo_part)
+
+    # if we can resolve an object for all parts of the algo endpoint we assume it is available
+    return True
+
+
+@pytest.mark.db_integration
+def test_algo_coverage(gds: AuraGraphDataScience) -> None:
+    """Test that all available Arrow actions are accessible through gds.v2"""
+    arrow_client = gds.v2._arrow_client
+
+    # Get all available Arrow actions
+    available_v2_actions = [
+        action.type.removeprefix("v2/") for action in arrow_client.list_actions() if action.type.startswith("v2/")
+    ]
+
+    algo_prefixes = ["pathfinding", "centrality", "community", "similarity", "embedding"]
+
+    # Filter to only v2 algorithm actions (exclude graph, model, catalog operations)
+    algorithm_actions: set[str] = {
+        action for action in available_v2_actions if any(action.startswith(prefix) for prefix in algo_prefixes)
+    }
+
+    missing_endpoints: set[str] = set()
+    available_endpoints: set[str] = set()
+
+    algos_per_category = defaultdict(list)
+    for action in algorithm_actions:
+        category, algo_parts = action.split(".", maxsplit=1)
+        algos_per_category[category].append(algo_parts)
+
+    for category, algos in algos_per_category.items():
+        for algo in algos:
+            is_available = check_gds_v2_availability(
+                gds.v2,
+                algo,
+            )
+            action = f"{category}.{algo}"
+            if is_available:
+                available_endpoints.add(action)
+            else:
+                missing_endpoints.add(action)
+
+    # Print summary
+    print("\nArrow Action Coverage Summary:")
+    print(f"Total algorithm actions found: {len(algorithm_actions)}")
+    print(f"Available through gds.v2: {len(available_endpoints)}")
+
+    # check if any previously missing algos are now available
+    assert not available_endpoints.intersection(MISSING_ALGO_ENDPOINTS), (
+        "Endpoints now available, please remove from MISSING_ALGO_ENDPOINTS"
+    )
+
+    # check missing endpoints against known missing algos
+    assert not missing_endpoints.difference(MISSING_ALGO_ENDPOINTS), "Unexpectedly missing endpoints"
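
The availability check resolves an action name by first mapping it through ENDPOINT_MAPPINGS, then snake-casing each dot-separated part and walking attributes on the endpoints object. A self-contained sketch of that resolution; DummyEndpoints and its nested class are hypothetical stand-ins for SessionV2Endpoints, not names from this commit:

    # Hypothetical sketch, not part of the commit.
    from pydantic.alias_generators import to_snake

    class _Centrality:
        def betweenness_centrality(self) -> None:
            pass

    class DummyEndpoints:
        centrality = _Centrality()

    def is_available(endpoints: object, algo: str) -> bool:
        # Walk one attribute per snake-cased part,
        # e.g. "centrality.pageRank" -> ["centrality", "page_rank"]
        obj = endpoints
        for part in (to_snake(p) for p in algo.split(".")):
            if not hasattr(obj, part):
                return False
            obj = getattr(obj, part)
        return True

    assert is_available(DummyEndpoints(), "centrality.betweenness_centrality")
    assert not is_available(DummyEndpoints(), "centrality.pageRank")  # no page_rank attribute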
