Skip to content

Commit 50c52e1

Browse files
committed
Add integration tests for WccCypherEndpoints
1 parent 95bb4dc commit 50c52e1

File tree

3 files changed

+108
-2
lines changed

3 files changed

+108
-2
lines changed

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_wcc_arrow_endpoints.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def sample_graph(arrow_client: AuthenticatedArrowClient):
2525
(a)-[:REL]->(c)
2626
"""
2727

28-
res = arrow_client.do_action( "v2/graph.fromGDL", json.dumps({"graphName": "g", "gdlGraph": gdl}).encode("utf-8"))
29-
print(deserialize_single(res))
28+
arrow_client.do_action( "v2/graph.fromGDL", json.dumps({"graphName": "g", "gdlGraph": gdl}).encode("utf-8"))
3029
yield MockGraph("g")
3130
arrow_client.do_action( "v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8"))
3231

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
from typing import Generator
3+
4+
import pytest
5+
from testcontainers.core.container import DockerContainer
6+
from testcontainers.core.waiting_utils import wait_for_logs
7+
from testcontainers.neo4j import Neo4jContainer
8+
9+
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
10+
11+
12+
@pytest.fixture(scope="session")
13+
def neo4j_database_container() -> Generator[Neo4jContainer, None, None]:
14+
neo4j_image = os.getenv("NEO4J_DATABASE_IMAGE", "neo4j:enterprise")
15+
16+
neo4j_container = (
17+
Neo4jContainer(
18+
image=neo4j_image,
19+
)
20+
.with_env("NEO4J_ACCEPT_LICENSE_AGREEMENT", "yes")
21+
.with_env("NEO4J_PLUGINS", '["graph-data-science"]')
22+
)
23+
24+
with neo4j_container as neo4j_db:
25+
wait_for_logs(neo4j_db, "Started.")
26+
yield neo4j_db
27+
28+
29+
@pytest.fixture
30+
def query_runner(neo4j_database_container: DockerContainer):
31+
yield Neo4jQueryRunner.create_for_db(
32+
f"bolt://localhost:{neo4j_database_container.get_exposed_port(7687)}",
33+
("neo4j", "password"),
34+
)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import pytest
2+
3+
from graphdatascience import Graph, QueryRunner
4+
from graphdatascience.procedure_surface.arrow.arrow_wcc_endpoints import WccArrowEndpoints
5+
from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints
6+
from graphdatascience.tests.integrationV2.procedure_surface.cypher.conftest import query_runner
7+
8+
9+
@pytest.fixture
10+
def sample_graph(query_runner: QueryRunner):
11+
create_statement = """
12+
CREATE
13+
(a: Node),
14+
(b: Node),
15+
(c: Node),
16+
(a)-[:REL]->(c)
17+
"""
18+
19+
query_runner.run_cypher(create_statement)
20+
21+
query_runner.run_cypher("""
22+
MATCH (n)
23+
OPTIONAL MATCH (n)-[r]->(m)
24+
WITH gds.graph.project('g', n, m, {}) AS G
25+
RETURN G
26+
""")
27+
28+
yield Graph("g", query_runner)
29+
30+
query_runner.run_cypher("CALL gds.graph.drop('g')")
31+
query_runner.run_cypher("MATCH (n) DETACH DELETE n")
32+
33+
@pytest.fixture
34+
def wcc_endpoints(query_runner: QueryRunner):
35+
yield WccCypherEndpoints(query_runner)
36+
37+
38+
def test_wcc_stats(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
39+
"""Test WCC stats operation."""
40+
result = wcc_endpoints.stats(
41+
G=sample_graph
42+
)
43+
44+
assert result.component_count == 2
45+
assert result.compute_millis > 0
46+
assert result.pre_processing_millis > 0
47+
assert result.post_processing_millis > 0
48+
assert "p10" in result.component_distribution
49+
50+
def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
51+
"""Test WCC stream operation."""
52+
result_df = wcc_endpoints.stream(
53+
G=sample_graph,
54+
)
55+
56+
assert "nodeId" in result_df.columns
57+
assert "componentId" in result_df.columns
58+
assert len(result_df.columns) == 2
59+
60+
def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
61+
"""Test WCC mutate operation."""
62+
result = wcc_endpoints.mutate(
63+
G=sample_graph,
64+
mutate_property="componentId",
65+
)
66+
67+
assert result.component_count == 2
68+
assert "p10" in result.component_distribution
69+
assert result.pre_processing_millis >= 0
70+
assert result.compute_millis >= 0
71+
assert result.post_processing_millis >= 0
72+
assert result.mutate_millis >= 0
73+
assert result.node_properties_written == 3

0 commit comments

Comments
 (0)