Skip to content

Commit cf63d86

Browse files
committed
Add pythonic projection to Cypher mapper DSL
1 parent f94a622 commit cf63d86

File tree

3 files changed

+472
-0
lines changed

3 files changed

+472
-0
lines changed
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
from collections import namedtuple
2+
from typing import Any, NamedTuple, Optional, Tuple
3+
4+
from pandas import Series
5+
6+
from ..error.illegal_attr_checker import IllegalAttrChecker
7+
from ..query_runner.query_runner import QueryRunner
8+
from ..server_version.server_version import ServerVersion
9+
from .graph_object import Graph
10+
11+
12+
class NodeProperty(NamedTuple):
    """Describe one node property that is part of a node projection."""

    # Name given to the property.
    name: str
    # Key under which the property is looked up.
    property_key: str
    # Fallback value; ``None`` when no default is given.
    default_value: Optional[Any] = None
16+
17+
18+
class NodeProjection(NamedTuple):
    """Describe how one source node label is projected."""

    # Label name used for the projected nodes.
    name: str
    # Label matched in the source graph.
    source_label: str
    # Property descriptions; ``None`` when no properties are given.
    properties: Optional[list[NodeProperty]] = None
22+
23+
24+
class RelationshipProperty(NamedTuple):
    """Describe one relationship property that is part of a relationship projection."""

    # Name given to the property.
    name: str
    # Key under which the property is looked up.
    property_key: str
    # Fallback value; ``None`` when no default is given.
    default_value: Optional[Any] = None
28+
29+
30+
class RelationshipProjection(NamedTuple):
    """Describe how one source relationship type is projected."""

    # Type name used for the projected relationships.
    name: str
    # Relationship type matched in the source graph.
    source_type: str
    # Property descriptions; ``None`` when no properties are given.
    properties: Optional[list[RelationshipProperty]] = None
34+
35+
36+
class MatchPart(NamedTuple):
    """Hold the clauses of the matching section of a query; empty strings mean "absent"."""

    match: str = ""
    source_where: str = ""
    optional_match: str = ""
    optional_where: str = ""

    def __str__(self) -> str:
        """Return the non-empty clauses, one per line, in declaration order."""
        # The tuple iterates its fields in the order they are declared above,
        # which is also the order in which the clauses are emitted.
        return "\n".join(filter(None, self))
53+
54+
55+
class MatchPattern(NamedTuple):
    """Hold the textual pieces of the relationship-and-target part of a match pattern."""

    label_filter: str = ""
    left_arrow: str = ""
    type_filter: str = ""
    right_arrow: str = ""

    def __str__(self) -> str:
        """Return the relationship segment followed by the ``target`` node pattern."""
        relationship = f"{self.left_arrow}{self.type_filter}{self.right_arrow}"
        target_node = f"(target{self.label_filter})"
        return relationship + target_node
63+
64+
65+
class GraphCypherRunner(IllegalAttrChecker):
    """
    Build and run a Cypher projection query from a Python description of
    node labels and relationship types.

    The generated query has the shape ``MATCH ... RETURN <namespace>(...)``,
    where ``<namespace>`` is the namespace handed to the constructor.
    """

    def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion) -> None:
        """
        Create the runner.

        Raises
        ------
        ValueError
            If ``server_version`` is older than 2.4.0.
        """
        if server_version < ServerVersion(2, 4, 0):
            raise ValueError("The new Cypher projection is only supported since GDS 2.4.0.")
        super().__init__(query_runner, namespace, server_version)

    def project(
        self,
        graph_name: str,
        *,
        nodes: Any = None,
        relationships: Any = None,
        where: Optional[str] = None,
        allow_disconnected_nodes: bool = False,
        inverse: bool = False,
        combine_labels_with: str = "OR",
        **config: Any,
    ) -> Tuple[Graph, "Series[Any]"]:
        """
        Project a graph using Cypher projection.

        Parameters
        ----------
        graph_name : str
            The name of the graph to project.
        nodes : Any
            The nodes to project: a label, a list of labels, or a dict mapping
            projected names to labels. If not specified, no label filter is
            applied.
        relationships : Any
            The relationships to project: a type, a list of types, or a dict
            mapping projected names to types. If not specified, no type
            filter is applied.
        where : Optional[str]
            A Cypher WHERE clause to filter the nodes and relationships to
            project.
            NOTE(review): this argument is accepted but not yet applied to
            the generated query.
        allow_disconnected_nodes : bool
            If true, source nodes are matched on their own and the
            relationship pattern is attached with ``OPTIONAL MATCH``.
        inverse : bool
            If true, the relationship pattern is written as
            ``(source)<-[...]-(target)`` instead of
            ``(source)-[...]->(target)``.
        combine_labels_with : str
            Whether to combine several node labels with AND or OR. The
            default is 'OR'. Allowed values are 'AND' and 'OR'. The value is
            only inspected when more than one label is given.
        **config : Any
            Additional configuration for the projection.

        Returns
        -------
        Tuple[Graph, Series[Any]]
            A handle for the projected graph and the squeezed query result.

        Raises
        ------
        ValueError
            If several node labels are given and ``combine_labels_with`` is
            neither 'AND' nor 'OR'.
        TypeError
            If ``nodes`` or ``relationships`` has an unsupported shape.
        """

        query_params: dict[str, Any] = {"graph_name": graph_name}

        # Data-config entries holding plain values travel in the
        # ``$data_config`` query parameter.
        static_config: dict[str, Any] = {}
        # Data-config entries holding Cypher expressions must be written into
        # the query text: a query parameter is a value and is never evaluated,
        # so e.g. "labels(source)" inside a parameter would stay a plain string.
        dynamic_config: dict[str, str] = {}

        node_specs = self._node_projections_spec(nodes)
        rel_specs = self._rel_projections_spec(relationships)

        match_part = MatchPart()
        match_pattern = MatchPattern(
            left_arrow="<-" if inverse else "-",
            right_arrow="-" if inverse else "->",
        )

        if node_specs:
            if len(node_specs) == 1 or combine_labels_with == "AND":
                # One label, or labels that must all be present, fit directly
                # into the node pattern as ``:A:B``.
                match_pattern = match_pattern._replace(
                    label_filter=f":{':'.join(spec.source_label for spec in node_specs)}"
                )

                projected_labels = [spec.name for spec in node_specs]
                static_config["sourceNodeLabels"] = projected_labels
                static_config["targetNodeLabels"] = projected_labels

            elif combine_labels_with == "OR":
                source_labels_filter = " OR ".join(f"source:{spec.source_label}" for spec in node_specs)
                target_labels_filter = " OR ".join(f"target:{spec.source_label}" for spec in node_specs)
                if allow_disconnected_nodes:
                    # Source and target are matched by separate clauses, so
                    # each clause gets its own filter.
                    match_part = match_part._replace(
                        source_where=f"WHERE {source_labels_filter}",
                        optional_where=f"WHERE {target_labels_filter}",
                    )
                else:
                    match_part = match_part._replace(
                        source_where=f"WHERE ({source_labels_filter}) AND ({target_labels_filter})"
                    )

                dynamic_config["sourceNodeLabels"] = "labels(source)"
                dynamic_config["targetNodeLabels"] = "labels(target)"
            else:
                raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}")

        if rel_specs:
            if len(rel_specs) == 1:
                rel_var = ""
                static_config["relationshipType"] = rel_specs[0].source_type
            else:
                # Several types: bind the relationship so that its actual type
                # can be read per matched row.
                rel_var = "rel"
                dynamic_config["relationshipType"] = "type(rel)"
            match_pattern = match_pattern._replace(
                type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rel_specs)}]"
            )

        source = f"(source{match_pattern.label_filter})"
        if allow_disconnected_nodes:
            match_part = match_part._replace(
                match=f"MATCH {source}",
                optional_match=f"OPTIONAL MATCH (source){match_pattern}",
            )
        else:
            match_part = match_part._replace(match=f"MATCH {source}{match_pattern}")

        match_clause = str(match_part)

        args = ["$graph_name", "source", "target"]

        if dynamic_config:
            # Render a map literal; static values are still taken from the
            # parameter so that they never have to be escaped into query text.
            entries = [f"{key}: $data_config.{key}" for key in static_config]
            entries += [f"{key}: {expression}" for key, expression in dynamic_config.items()]
            args.append("{" + ", ".join(entries) + "}")
            if static_config:
                query_params["data_config"] = static_config
        elif static_config:
            query_params["data_config"] = static_config
            args.append("$data_config")

        if config:
            query_params["config"] = config
            args.append("$config")

        return_part = f"RETURN {self._namespace}({', '.join(args)})"

        query = "\n".join(part for part in [match_clause, return_part] if part)

        result = self._query_runner.run_query_with_logging(
            query,
            query_params,
        ).squeeze()

        return Graph(graph_name, self._query_runner, self._server_version), result

    def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
        """
        Normalise the ``nodes`` argument into a list of node projections.

        ``None`` and ``False`` give an empty list, a string is treated as a
        one-element list, a list is converted element by element, and a dict
        maps projected names to specifications.

        Raises
        ------
        TypeError
            If ``spec`` has any other type.
        """
        if spec is None or spec is False:
            return []

        if isinstance(spec, str):
            spec = [spec]

        if isinstance(spec, list):
            return [self._node_projection_spec(node) for node in spec]

        if isinstance(spec, dict):
            return [self._node_projection_spec(node, name) for name, node in spec.items()]

        raise TypeError(f"Invalid node projection specification: {spec}")

    def _node_projection_spec(self, spec: Any, name: Optional[str] = None) -> NodeProjection:
        """
        Convert a single node specification; only strings are supported.

        The projected name falls back to the label itself when ``name`` is
        missing or empty.

        Raises
        ------
        TypeError
            If ``spec`` is not a string.
        """
        if isinstance(spec, str):
            return NodeProjection(name=name or spec, source_label=spec)

        raise TypeError(f"Invalid node projection specification: {spec}")

    def _node_properties_spec(self, properties: dict[str, Any]) -> list[NodeProperty]:
        """Placeholder: no node property specification is accepted yet, so this always raises ``TypeError``."""
        raise TypeError(f"Invalid node projection specification: {properties}")

    def _rel_projections_spec(self, spec: Any) -> list[RelationshipProjection]:
        """
        Normalise the ``relationships`` argument into a list of relationship projections.

        ``None`` and ``False`` give an empty list, a string is treated as a
        one-element list, a list is converted element by element, and a dict
        maps projected names to specifications.

        Raises
        ------
        TypeError
            If ``spec`` has any other type.
        """
        if spec is None or spec is False:
            return []

        if isinstance(spec, str):
            spec = [spec]

        if isinstance(spec, list):
            return [self._rel_projection_spec(rel) for rel in spec]

        if isinstance(spec, dict):
            return [self._rel_projection_spec(rel, name) for name, rel in spec.items()]

        raise TypeError(f"Invalid relationship projection specification: {spec}")

    def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> RelationshipProjection:
        """
        Convert a single relationship specification; only strings are supported.

        The projected name falls back to the type itself when ``name`` is
        missing or empty.

        Raises
        ------
        TypeError
            If ``spec`` is not a string.
        """
        if isinstance(spec, str):
            return RelationshipProjection(name=name or spec, source_type=spec)

        raise TypeError(f"Invalid relationship projection specification: {spec}")

    def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]:
        """Placeholder: no relationship property specification is accepted yet, so this always raises ``TypeError``."""
        raise TypeError(f"Invalid relationship projection specification: {properties}")
239+
240+
#
241+
# def estimate(self, *, nodes: Any, relationships: Any, **config: Any) -> "Series[Any]":
242+
# pass

graphdatascience/graph/graph_proc_runner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .graph_sample_runner import GraphSampleRunner
2828
from .graph_type_check import graph_type_check, graph_type_check_optional
2929
from .ogb_loader import OGBLLoader, OGBNLoader
30+
from graphdatascience.graph.graph_cypher_runner import GraphCypherRunner
3031

3132
Strings = Union[str, List[str]]
3233

@@ -165,6 +166,11 @@ def project(self) -> GraphProjectRunner:
165166
self._namespace += ".project"
166167
return GraphProjectRunner(self._query_runner, self._namespace, self._server_version)
167168

169+
@property
def cypher(self) -> GraphCypherRunner:
    """Return a Cypher projection runner; accessing this appends ".project" to the stored namespace."""
    namespace = self._namespace + ".project"
    # The extended namespace is stored back on the instance before the runner is built.
    self._namespace = namespace
    return GraphCypherRunner(self._query_runner, namespace, self._server_version)
173+
168174
@property
169175
def export(self) -> GraphExportRunner:
170176
self._namespace += ".export"

0 commit comments

Comments
 (0)