Skip to content

Commit 95bb4dc

Browse files
committed
Add integration tests for WccArrowEndpoints
1 parent 6b131c6 commit 95bb4dc

File tree

3 files changed

+86
-8
lines changed

3 files changed

+86
-8
lines changed

graphdatascience/arrow_client/v2/job_client.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,24 @@ def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]
4545
return deserialize_single(res)
4646

4747
@staticmethod
48-
def stream_results(client: AuthenticatedArrowClient, job_id: str) -> DataFrame:
49-
encoded_config = JobIdConfig(jobId=job_id).dump_json().encode("utf-8")
48+
def stream_results(client: AuthenticatedArrowClient,graph_name: str, job_id: str) -> DataFrame:
49+
payload = {
50+
"graphName": graph_name,
51+
"jobId": job_id,
52+
}
5053

51-
res = client.do_action_with_retry("v2/results.stream", encoded_config)
54+
res = client.do_action_with_retry("v2/results.stream", json.dumps(payload).encode("utf-8"))
5255
export_job_id = JobIdConfig(**deserialize_single(res)).job_id
5356

5457
payload = {
58+
"version": "v2",
5559
"name": export_job_id,
56-
"version": 1,
60+
"body": {}
5761
}
5862

5963
ticket = Ticket(json.dumps(payload).encode("utf-8"))
60-
with client.get_stream(ticket) as get:
61-
arrow_table = get.read_all()
6264

65+
get = client.get_stream(ticket)
66+
arrow_table = get.read_all()
6367
return arrow_table.to_pandas(types_mapper=ArrowDtype) # type: ignore
68+

graphdatascience/procedure_surface/arrow/arrow_wcc_endpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
from typing import List, Optional, Any
1+
from typing import Any, List, Optional
22

33
from pandas import DataFrame
44

5-
from ..api.estimation_result import EstimationResult
65
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
76
from ...arrow_client.v2.job_client import JobClient
87
from ...arrow_client.v2.mutation_client import MutationClient
98
from ...arrow_client.v2.write_back_client import WriteBackClient
109
from ...graph.graph_object import Graph
10+
from ..api.estimation_result import EstimationResult
1111
from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult
1212
from ..utils.config_converter import ConfigConverter
1313

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import json
2+
3+
import pytest
4+
5+
from graphdatascience import Graph
6+
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7+
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
8+
from graphdatascience.procedure_surface.arrow.arrow_wcc_endpoints import WccArrowEndpoints
9+
10+
11+
class MockGraph(Graph):
12+
def __init__(self, name: str):
13+
self._name = name
14+
15+
def name(self) -> str:
16+
return self._name
17+
18+
19+
@pytest.fixture
20+
def sample_graph(arrow_client: AuthenticatedArrowClient):
21+
gdl = """
22+
(a: Node)
23+
(b: Node)
24+
(c: Node)
25+
(a)-[:REL]->(c)
26+
"""
27+
28+
res = arrow_client.do_action( "v2/graph.fromGDL", json.dumps({"graphName": "g", "gdlGraph": gdl}).encode("utf-8"))
29+
print(deserialize_single(res))
30+
yield MockGraph("g")
31+
arrow_client.do_action( "v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8"))
32+
33+
@pytest.fixture
34+
def wcc_endpoints(arrow_client: AuthenticatedArrowClient):
35+
yield WccArrowEndpoints(arrow_client)
36+
37+
38+
def test_wcc_stats(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
39+
"""Test WCC stats operation."""
40+
result = wcc_endpoints.stats(
41+
G=sample_graph
42+
)
43+
44+
assert result.component_count == 2
45+
assert result.compute_millis > 0
46+
assert result.pre_processing_millis > 0
47+
assert result.post_processing_millis > 0
48+
assert "p10" in result.component_distribution
49+
50+
def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
51+
"""Test WCC stream operation."""
52+
result_df = wcc_endpoints.stream(
53+
G=sample_graph,
54+
)
55+
56+
assert "nodeId" in result_df.columns
57+
assert "componentId" in result_df.columns
58+
assert len(result_df.columns) == 2
59+
60+
def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
61+
"""Test WCC mutate operation."""
62+
result = wcc_endpoints.mutate(
63+
G=sample_graph,
64+
mutate_property="componentId",
65+
)
66+
67+
assert result.component_count == 2
68+
assert "p10" in result.component_distribution
69+
assert result.pre_processing_millis >= 0
70+
assert result.compute_millis >= 0
71+
assert result.post_processing_millis >= 0
72+
assert result.mutate_millis >= 0
73+
assert result.node_properties_written == 3

0 commit comments

Comments
 (0)