Skip to content

Commit 2bfff47

Browse files
DarthMaxMats-SX
authored andcommitted
Add a write test for ArticleRankArrowEndpoints
1 parent 1464300 commit 2bfff47

File tree

1 file changed

+42
-10
lines changed

1 file changed

+42
-10
lines changed

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_articlerank_arrow_endpoints.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,43 @@
22

33
import pytest
44

5+
from graphdatascience import Graph, QueryRunner
56
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7+
from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
68
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
79
from graphdatascience.procedure_surface.arrow.articlerank_arrow_endpoints import ArticleRankArrowEndpoints
8-
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph
10+
from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph, \
11+
create_graph_from_db
912

1013

14+
graph = """
15+
CREATE
16+
(a: Node),
17+
(b: Node),
18+
(c: Node),
19+
(a)-[:REL]->(c),
20+
(b)-[:REL]->(c)
21+
"""
22+
1123
@pytest.fixture
1224
def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]:
13-
gdl = """
14-
(a: Node)
15-
(b: Node)
16-
(c: Node)
17-
(a)-[:REL]->(c)
18-
(b)-[:REL]->(c)
19-
"""
20-
21-
with create_graph(arrow_client, "g", gdl) as G:
25+
with create_graph(arrow_client, "g", graph) as G:
2226
yield G
2327

28+
@pytest.fixture
29+
def db_graph(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> Generator[Graph, None, None]:
30+
with create_graph_from_db(
31+
arrow_client,
32+
query_runner,
33+
"g",
34+
graph,
35+
"""
36+
MATCH (n)-->(m)
37+
WITH gds.graph.project.remote(n, m) as g
38+
RETURN g
39+
"""
40+
) as g:
41+
yield g
2442

2543
@pytest.fixture
2644
def articlerank_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[ArticleRankArrowEndpoints, None, None]:
@@ -67,6 +85,20 @@ def test_articlerank_mutate(articlerank_endpoints: ArticleRankArrowEndpoints, sa
6785
assert result.mutate_millis >= 0
6886
assert result.node_properties_written == 3
6987

88+
def test_articlerank_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: Graph) -> None:
89+
endpoints = ArticleRankArrowEndpoints(arrow_client, RemoteWriteBackClient(arrow_client,query_runner))
90+
result = endpoints.write(G=db_graph, write_property="write")
91+
92+
assert result.did_converge
93+
assert "p50" in result.centrality_distribution
94+
assert result.pre_processing_millis >= 0
95+
assert result.compute_millis >= 0
96+
assert result.post_processing_millis >= 0
97+
assert result.write_millis >= 0
98+
assert result.node_properties_written == 3
99+
100+
assert query_runner.run_cypher("MATCH (n) WHERE n.write IS NOT NULL RETURN COUNT(*) AS count").squeeze() == 3
101+
70102

71103
def test_articlerank_estimate(articlerank_endpoints: ArticleRankArrowEndpoints, sample_graph: GraphV2) -> None:
72104
result = articlerank_endpoints.estimate(sample_graph)

0 commit comments

Comments
 (0)