|
2 | 2 |
|
3 | 3 | import pytest |
4 | 4 |
|
| 5 | +from graphdatascience import Graph, QueryRunner |
5 | 6 | from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient |
| 7 | +from graphdatascience.arrow_client.v2.remote_write_back_client import RemoteWriteBackClient |
6 | 8 | from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2 |
7 | 9 | from graphdatascience.procedure_surface.arrow.articlerank_arrow_endpoints import ArticleRankArrowEndpoints |
8 | | -from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph |
| 10 | +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph, \ |
| 11 | + create_graph_from_db |
9 | 12 |
|
10 | 13 |
|
| 14 | +graph = """ |
| 15 | + CREATE |
| 16 | + (a: Node), |
| 17 | + (b: Node), |
| 18 | + (c: Node), |
| 19 | + (a)-[:REL]->(c), |
| 20 | + (b)-[:REL]->(c) |
| 21 | + """ |
| 22 | + |
11 | 23 | @pytest.fixture |
12 | 24 | def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[GraphV2, None, None]: |
13 | | - gdl = """ |
14 | | - (a: Node) |
15 | | - (b: Node) |
16 | | - (c: Node) |
17 | | - (a)-[:REL]->(c) |
18 | | - (b)-[:REL]->(c) |
19 | | - """ |
20 | | - |
21 | | - with create_graph(arrow_client, "g", gdl) as G: |
| 25 | + with create_graph(arrow_client, "g", graph) as G: |
22 | 26 | yield G |
23 | 27 |
|
| 28 | +@pytest.fixture |
| 29 | +def db_graph(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner) -> Generator[Graph, None, None]: |
| 30 | + with create_graph_from_db( |
| 31 | + arrow_client, |
| 32 | + query_runner, |
| 33 | + "g", |
| 34 | + graph, |
| 35 | + """ |
| 36 | + MATCH (n)-->(m) |
| 37 | + WITH gds.graph.project.remote(n, m) as g |
| 38 | + RETURN g |
| 39 | + """ |
| 40 | + ) as g: |
| 41 | + yield g |
24 | 42 |
|
25 | 43 | @pytest.fixture |
26 | 44 | def articlerank_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[ArticleRankArrowEndpoints, None, None]: |
@@ -67,6 +85,20 @@ def test_articlerank_mutate(articlerank_endpoints: ArticleRankArrowEndpoints, sa |
67 | 85 | assert result.mutate_millis >= 0 |
68 | 86 | assert result.node_properties_written == 3 |
69 | 87 |
|
| 88 | +def test_articlerank_write(arrow_client: AuthenticatedArrowClient, query_runner: QueryRunner, db_graph: Graph) -> None: |
| 89 | + endpoints = ArticleRankArrowEndpoints(arrow_client, RemoteWriteBackClient(arrow_client,query_runner)) |
| 90 | + result = endpoints.write(G=db_graph, write_property="write") |
| 91 | + |
| 92 | + assert result.did_converge |
| 93 | + assert "p50" in result.centrality_distribution |
| 94 | + assert result.pre_processing_millis >= 0 |
| 95 | + assert result.compute_millis >= 0 |
| 96 | + assert result.post_processing_millis >= 0 |
| 97 | + assert result.write_millis >= 0 |
| 98 | + assert result.node_properties_written == 3 |
| 99 | + |
| 100 | + assert query_runner.run_cypher("MATCH (n) WHERE n.write IS NOT NULL RETURN COUNT(*) AS count").squeeze() == 3 |
| 101 | + |
70 | 102 |
|
71 | 103 | def test_articlerank_estimate(articlerank_endpoints: ArticleRankArrowEndpoints, sample_graph: GraphV2) -> None: |
72 | 104 | result = articlerank_endpoints.estimate(sample_graph) |
|
0 commit comments