Skip to content

Commit 5e2cd81

Browse files
committed
Implement WccEndpoints with Cypher
1 parent b93c117 commit 5e2cd81

File tree

6 files changed

+687
-0
lines changed

6 files changed

+687
-0
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
from typing import Any, List, Optional, Union
2+
3+
from pandas import DataFrame
4+
5+
from ...call_parameters import CallParameters
6+
from ...graph.graph_object import Graph
7+
from ...query_runner.query_runner import QueryRunner
8+
from ..api.estimation_result import EstimationResult
9+
from ..api.wcc_endpoints import WccEndpoints, WccMutateResult, WccStatsResult, WccWriteResult
10+
from ..utils.config_converter import ConfigConverter
11+
12+
13+
class WccCypherEndpoints(WccEndpoints):
    """Cypher-based implementation of the WCC algorithm endpoints.

    Each public method translates its keyword arguments into a GDS
    configuration map and forwards the actual procedure call to the
    query runner.
    """

    def __init__(self, query_runner: QueryRunner):
        # All Cypher round trips go through this runner.
        self._query_runner = query_runner

    def _execute(self, endpoint: str, G: Graph, config: dict[str, Any]) -> Any:
        """Wrap *config* in call parameters for *G*, guarantee a job id and run *endpoint*."""
        params = CallParameters(graph_name=G.name(), config=config)
        params.ensure_job_id_in_config()
        return self._query_runner.call_procedure(endpoint=endpoint, params=params)

    def mutate(
        self,
        G: Graph,
        mutate_property: str,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccMutateResult:
        """Run `gds.wcc.mutate` and return its single result row as a WccMutateResult."""
        config = ConfigConverter.convert_to_gds_config(
            mutate_property=mutate_property,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
            username=username,
        )

        row = self._execute("gds.wcc.mutate", G, config).squeeze()
        return WccMutateResult(**row.to_dict())

    def stats(
        self,
        G: Graph,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> WccStatsResult:
        """Run `gds.wcc.stats` and return its single result row as a WccStatsResult."""
        config = ConfigConverter.convert_to_gds_config(
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
            username=username,
        )

        row = self._execute("gds.wcc.stats", G, config).squeeze()
        return WccStatsResult(**row.to_dict())

    def stream(
        self,
        G: Graph,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
    ) -> DataFrame:
        """Run `gds.wcc.stream` and return the streamed rows as a DataFrame."""
        config = ConfigConverter.convert_to_gds_config(
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
            username=username,
        )

        # Stream mode returns many rows, so no squeeze here.
        return self._execute("gds.wcc.stream", G, config)

    def write(
        self,
        G: Graph,
        write_property: str,
        min_component_size: Optional[int] = None,
        threshold: Optional[float] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: Optional[bool] = None,
        username: Optional[str] = None,
        concurrency: Optional[int] = None,
        job_id: Optional[str] = None,
        seed_property: Optional[str] = None,
        consecutive_ids: Optional[bool] = None,
        relationship_weight_property: Optional[str] = None,
        write_concurrency: Optional[int] = None,
    ) -> WccWriteResult:
        """Run `gds.wcc.write` and return its single result row as a WccWriteResult."""
        config = ConfigConverter.convert_to_gds_config(
            write_property=write_property,
            concurrency=concurrency,
            consecutive_ids=consecutive_ids,
            job_id=job_id,
            log_progress=log_progress,
            min_component_size=min_component_size,
            node_labels=node_labels,
            relationship_types=relationship_types,
            relationship_weight_property=relationship_weight_property,
            seed_property=seed_property,
            sudo=sudo,
            threshold=threshold,
            username=username,
        )

        if write_concurrency is not None:
            config["writeConcurrency"] = write_concurrency

        row = self._execute("gds.wcc.write", G, config).squeeze()
        return WccWriteResult(**row.to_dict())

    def estimate(
        self, graph_name: Optional[str] = None, projection_config: Optional[dict[str, Any]] = None
    ) -> EstimationResult:
        """Estimate the memory needed for WCC on a named graph or a projection config.

        Exactly one of *graph_name* and *projection_config* must be given;
        *graph_name* wins when both are supplied.
        """
        config: Union[str, dict[str, Any]]
        if graph_name is not None:
            config = graph_name
        elif projection_config is not None:
            config = projection_config
        else:
            raise ValueError("Either graph_name or projection_config must be provided.")

        params = CallParameters(config=config)

        row = self._query_runner.call_procedure(endpoint="gds.wcc.stats.estimate", params=params).squeeze()
        return EstimationResult(**row.to_dict())
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Any, Dict, Optional
2+
3+
4+
class ConfigConverter:
    """Translate snake_case keyword arguments into a GDS-style camelCase config map."""

    @staticmethod
    def convert_to_gds_config(**kwargs: Optional[Any]) -> dict[str, Any]:
        """Build a GDS configuration dict from keyword arguments.

        Keys are converted from snake_case to camelCase. Entries whose value
        is ``None`` are dropped so that server-side defaults apply; nested
        dictionaries are converted recursively.
        """
        # _process_dict_values already returns a fresh dict, so no
        # intermediate empty dict / update() round trip is needed.
        return ConfigConverter._process_dict_values(kwargs)

    @staticmethod
    def _convert_to_camel_case(name: str) -> str:
        """Convert a snake_case string to camelCase."""
        first, *rest = name.split("_")
        return first.lower() + "".join(word.capitalize() for word in rest)

    @staticmethod
    def _process_dict_values(input_dict: dict[str, Any]) -> dict[str, Any]:
        """Camel-case all keys, recurse into nested dicts and skip None values."""
        result: dict[str, Any] = {}
        for key, value in input_dict.items():
            if value is None:
                continue  # omitted options fall back to server defaults
            camel_key = ConfigConverter._convert_to_camel_case(key)
            result[camel_key] = (
                ConfigConverter._process_dict_values(value) if isinstance(value, dict) else value
            )
        return result

graphdatascience/tests/unit/procedure_surface/__init__.py

Whitespace-only changes.

graphdatascience/tests/unit/procedure_surface/api/__init__.py

Whitespace-only changes.
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
2+
3+
4+
def test_estimation_result_initialization() -> None:
    """camelCase constructor fields are exposed as snake_case attributes."""
    raw = {
        "nodeCount": 5,
        "relationshipCount": 10,
        "requiredMemory": "512MB",
        "treeView": "TreeData",
        "mapView": {"key": "value"},
        "bytesMin": 500,
        "bytesMax": 1000,
        "heapPercentageMin": 0.1,
        "heapPercentageMax": 0.5,
    }

    estimation_result = EstimationResult(**raw)

    assert estimation_result.node_count == 5
    assert estimation_result.relationship_count == 10
    assert estimation_result.required_memory == "512MB"
    assert estimation_result.tree_view == "TreeData"
    assert estimation_result.map_view == {"key": "value"}
    assert estimation_result.bytes_min == 500
    assert estimation_result.bytes_max == 1000
    assert estimation_result.heap_percentage_min == 0.1
    assert estimation_result.heap_percentage_max == 0.5
25+
26+
27+
def test_estimation_result_getitem() -> None:
    """snake_case keys are readable through subscript access."""
    estimation_result = EstimationResult(
        nodeCount=5,
        relationshipCount=10,
        requiredMemory="512MB",
        treeView="TreeData",
        mapView={"key": "value"},
        bytesMin=500,
        bytesMax=1000,
        heapPercentageMin=0.1,
        heapPercentageMax=0.5,
    )

    # Insertion order matches the original one-assert-per-field checks.
    expected = {
        "node_count": 5,
        "relationship_count": 10,
        "required_memory": "512MB",
        "tree_view": "TreeData",
        "map_view": {"key": "value"},
        "bytes_min": 500,
        "bytes_max": 1000,
        "heap_percentage_min": 0.1,
        "heap_percentage_max": 0.5,
    }
    for field, value in expected.items():
        assert estimation_result[field] == value
48+
49+
50+
def test_estimation_result_from_cypher() -> None:
    """A raw Cypher result row round-trips through from_cypher into attributes."""
    cypher_result = {
        "nodeCount": 5,
        "relationshipCount": 10,
        "requiredMemory": "512MB",
        "treeView": "TreeData",
        "mapView": {"key": "value"},
        "bytesMin": 500,
        "bytesMax": 1000,
        "heapPercentageMin": 0.1,
        "heapPercentageMax": 0.5,
    }

    estimation_result = EstimationResult.from_cypher(cypher_result)

    assert isinstance(estimation_result, EstimationResult)
    # Insertion order matches the original one-assert-per-field checks.
    expected = {
        "node_count": 5,
        "relationship_count": 10,
        "required_memory": "512MB",
        "tree_view": "TreeData",
        "map_view": {"key": "value"},
        "bytes_min": 500,
        "bytes_max": 1000,
        "heap_percentage_min": 0.1,
        "heap_percentage_max": 0.5,
    }
    for field, value in expected.items():
        assert getattr(estimation_result, field) == value

0 commit comments

Comments
 (0)