Skip to content

Commit c1ec516

Browse files
DarthMaxMats-SX
authored andcommitted
Add missing write tests for arrow endpoints
1 parent 2bfff47 commit c1ec516

25 files changed

+977
-276
lines changed

graphdatascience/procedure_surface/arrow/celf_arrow_endpoints.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,12 @@ def write(
147147
)
148148

149149
result = self._node_property_endpoints.run_job_and_write(
150-
"v2/centrality.celf", G, config, write_concurrency=write_concurrency, concurrency=concurrency, property_overwrites=write_property
150+
"v2/centrality.celf",
151+
G,
152+
config,
153+
write_concurrency=write_concurrency,
154+
concurrency=concurrency,
155+
property_overwrites=write_property,
151156
)
152157

153158
return CelfWriteResult(**result)

graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@
1111
)
1212

1313
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
14+
from ...arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
1415
from .model_api_arrow import ModelApiArrow
1516
from .node_property_endpoints import NodePropertyEndpoints
1617

1718

1819
class GraphSagePredictArrowEndpoints(GraphSagePredictEndpoints):
19-
def __init__(self, arrow_client: AuthenticatedArrowClient):
20+
def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[RemoteWriteBackClient]):
2021
self._arrow_client = arrow_client
21-
self._node_property_endpoints = NodePropertyEndpoints(arrow_client)
22+
self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client)
2223
self._model_api = ModelApiArrow(arrow_client)
2324

2425
def stream(
@@ -79,12 +80,7 @@ def write(
7980
)
8081

8182
raw_result = self._node_property_endpoints.run_job_and_write(
82-
"v2/embeddings.graphSage",
83-
G,
84-
config,
85-
write_concurrency,
86-
concurrency,
87-
write_property
83+
"v2/embeddings.graphSage", G, config, write_concurrency, concurrency, write_property
8884
)
8985

9086
return GraphSageWriteResult(**raw_result)

graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from graphdatascience.procedure_surface.arrow.graphsage_predict_arrow_endpoints import GraphSagePredictArrowEndpoints
66

77
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8+
from ...arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
9+
from ...graph.graph_object import Graph
810
from ..api.graphsage_train_endpoints import (
911
GraphSageTrainEndpoints,
1012
GraphSageTrainResult,
@@ -14,9 +16,10 @@
1416

1517

1618
class GraphSageTrainArrowEndpoints(GraphSageTrainEndpoints):
17-
def __init__(self, arrow_client: AuthenticatedArrowClient):
19+
def __init__(self, arrow_client: AuthenticatedArrowClient, write_back_client: Optional[RemoteWriteBackClient]):
1820
self._arrow_client = arrow_client
19-
self._node_property_endpoints = NodePropertyEndpoints(arrow_client)
21+
self._write_back_client = write_back_client
22+
self._node_property_endpoints = NodePropertyEndpoints(arrow_client, write_back_client=write_back_client)
2023
self._model_api = ModelApiArrow(arrow_client)
2124

2225
def train(
@@ -83,7 +86,9 @@ def train(
8386
result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.graphSage.train", G, config)
8487

8588
model = GraphSageModelV2(
86-
model_name, self._model_api, predict_endpoints=GraphSagePredictArrowEndpoints(self._arrow_client)
89+
model_name,
90+
self._model_api,
91+
predict_endpoints=GraphSagePredictArrowEndpoints(self._arrow_client, self._write_back_client),
8792
)
8893
train_result = GraphSageTrainResult(**result)
8994

graphdatascience/session/session_v2_endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ def fast_rp(self) -> FastRPArrowEndpoints:
6868

6969
@property
7070
def graphsage_predict(self) -> GraphSagePredictArrowEndpoints:
71-
return GraphSagePredictArrowEndpoints(self._arrow_client)
71+
return GraphSagePredictArrowEndpoints(self._arrow_client, self._write_back_client)
7272

7373
@property
7474
def graphsage_train(self) -> GraphSageTrainArrowEndpoints:
75-
return GraphSageTrainArrowEndpoints(self._arrow_client)
75+
return GraphSageTrainArrowEndpoints(self._arrow_client, self._write_back_client)
7676

7777
@property
7878
def harmonic_centrality(self) -> ClosenessHarmonicArrowEndpoints:

graphdatascience/tests/integrationV2/procedure_surface/arrow/graph_creation_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,14 @@ def create_graph_from_db(
4343
graph_name: str,
4444
graph_data: str,
4545
query: str,
46+
undirected_relationship_types: Optional[list[str]] = None,
4647
) -> Generator[GraphV2, Any, None]:
4748
try:
4849
query_runner.run_cypher(graph_data)
4950
result = CatalogArrowEndpoints(arrow_client, query_runner).project(
5051
graph_name=graph_name,
5152
query=query,
53+
undirected_relationship_types=undirected_relationship_types,
5254
)
5355

5456
yield result.graph

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_articlerank_arrow_endpoints.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
88
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
99
from graphdatascience.procedure_surface.arrow.articlerank_arrow_endpoints import ArticleRankArrowEndpoints
10-
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph, \
11-
create_graph_from_db
12-
10+
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import (
11+
create_graph,
12+
create_graph_from_db,
13+
)
1314

1415
graph = """
1516
CREATE
@@ -20,26 +21,29 @@
2021
(b)-[:REL]->(c)
2122
"""
2223

24+
2325
@pytest.fixture
2426
def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]:
2527
with create_graph(arrow_client, "g", graph) as G:
2628
yield G
2729

30+
2831
@pytest.fixture
2932
def db_graph(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> Generator[Graph, None, None]:
3033
with create_graph_from_db(
31-
arrow_client,
32-
query_runner,
33-
"g",
34-
graph,
35-
"""
34+
arrow_client,
35+
query_runner,
36+
"g",
37+
graph,
38+
"""
3639
MATCH (n)-->(m)
3740
WITH gds.graph.project.remote(n, m) as g
3841
RETURN g
39-
"""
42+
""",
4043
) as g:
4144
yield g
4245

46+
4347
@pytest.fixture
4448
def articlerank_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[ArticleRankArrowEndpoints, None, None]:
4549
yield ArticleRankArrowEndpoints(arrow_client)
@@ -85,8 +89,9 @@ def test_articlerank_mutate(articlerank_endpoints: ArticleRankArrowEndpoints, sa
8589
assert result.mutate_millis >= 0
8690
assert result.node_properties_written == 3
8791

92+
8893
def test_articlerank_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: Graph) -> None:
89-
endpoints = ArticleRankArrowEndpoints(arrow_client, RemoteWriteBackClient(arrow_client,query_runner))
94+
endpoints = ArticleRankArrowEndpoints(arrow_client, RemoteWriteBackClient(arrow_client, query_runner))
9095
result = endpoints.write(G=db_graph, write_property="write")
9196

9297
assert result.did_converge

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_articulationpoints_arrow_endpoints.py

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,56 @@
22

33
import pytest
44

5+
from graphdatascience import Graph, QueryRunner
56
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7+
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
68
from graphdatascience.procedure_surface.api.articulationpoints_endpoints import (
79
ArticulationPointsMutateResult,
810
ArticulationPointsStatsResult,
11+
ArticulationPointsWriteResult,
912
)
1013
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
1114
from graphdatascience.procedure_surface.arrow.articulationpoints_arrow_endpoints import (
1215
ArticulationPointsArrowEndpoints,
1316
)
14-
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph
15-
17+
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import (
18+
create_graph,
19+
create_graph_from_db,
20+
)
1621

17-
@pytest.fixture
18-
def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]:
19-
gdl = """
20-
(a: Node)
21-
(b: Node)
22-
(c: Node)
23-
(a)-[:REL]->(c)
22+
graph = """
23+
CREATE
24+
(a: Node),
25+
(b: Node),
26+
(c: Node),
27+
(a)-[:REL]->(c),
2428
(b)-[:REL]->(c)
2529
"""
2630

27-
with create_graph(arrow_client, "g", gdl, undirected=("REL", "UNDIRECTED_REL")) as G:
31+
32+
@pytest.fixture
33+
def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]:
34+
with create_graph(arrow_client, "g", graph, undirected=("REL", "UNDIRECTED_REL")) as G:
2835
yield G
2936

3037

38+
@pytest.fixture
39+
def db_graph(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> Generator[Graph, None, None]:
40+
with create_graph_from_db(
41+
arrow_client,
42+
query_runner,
43+
"g",
44+
graph,
45+
"""
46+
MATCH (n)-->(m)
47+
WITH gds.graph.project.remote(n, m, {relationshipType: "REL"}) as g
48+
RETURN g
49+
""",
50+
["REL"],
51+
) as g:
52+
yield g
53+
54+
3155
@pytest.fixture
3256
def articulationpoints_endpoints(arrow_client: AuthenticatedArrowClient) -> ArticulationPointsArrowEndpoints:
3357
return ArticulationPointsArrowEndpoints(arrow_client)
@@ -36,7 +60,6 @@ def articulationpoints_endpoints(arrow_client: AuthenticatedArrowClient) -> Arti
3660
def test_articulationpoints_mutate(
3761
articulationpoints_endpoints: ArticulationPointsArrowEndpoints, sample_graph: GraphV2
3862
) -> None:
39-
"""Test ArticulationPoints mutate operation."""
4063
result = articulationpoints_endpoints.mutate(
4164
G=sample_graph,
4265
mutate_property="articulationPoint",
@@ -53,7 +76,6 @@ def test_articulationpoints_mutate(
5376
def test_articulationpoints_stats(
5477
articulationpoints_endpoints: ArticulationPointsArrowEndpoints, sample_graph: GraphV2
5578
) -> None:
56-
"""Test ArticulationPoints stats operation."""
5779
result = articulationpoints_endpoints.stats(sample_graph)
5880

5981
assert isinstance(result, ArticulationPointsStatsResult)
@@ -64,17 +86,33 @@ def test_articulationpoints_stats(
6486
def test_articulationpoints_stream_not_implemented(
6587
articulationpoints_endpoints: ArticulationPointsArrowEndpoints, sample_graph: GraphV2
6688
) -> None:
67-
"""Test that ArticulationPoints stream raises NotImplementedError."""
6889
with pytest.raises(
6990
NotImplementedError, match="Stream mode is not supported for ArticulationPoints arrow endpoints"
7091
):
7192
articulationpoints_endpoints.stream(sample_graph)
7293

7394

95+
def test_articulationpoints_write(
96+
arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: Graph
97+
) -> None:
98+
endpoints = ArticulationPointsArrowEndpoints(arrow_client, RemoteWriteBackClient(arrow_client, query_runner))
99+
result = endpoints.write(G=db_graph, write_property="articulationPoint")
100+
101+
assert isinstance(result, ArticulationPointsWriteResult)
102+
assert result.articulation_point_count >= 0
103+
assert result.compute_millis >= 0
104+
assert result.write_millis >= 0
105+
assert result.node_properties_written == 3
106+
107+
assert (
108+
query_runner.run_cypher("MATCH (n) WHERE n.articulationPoint IS NOT NULL RETURN COUNT(*) AS count").squeeze()
109+
== 3
110+
)
111+
112+
74113
def test_articulationpoints_estimate(
75114
articulationpoints_endpoints: ArticulationPointsArrowEndpoints, sample_graph: GraphV2
76115
) -> None:
77-
"""Test ArticulationPoints memory estimation."""
78116
result = articulationpoints_endpoints.estimate(sample_graph)
79117

80118
assert result.node_count == 3

0 commit comments

Comments
 (0)