Skip to content

Commit 143270f

Browse files
committed
Add basic node property support
1 parent f7735c0 commit 143270f

File tree

2 files changed

+94
-8
lines changed

2 files changed

+94
-8
lines changed

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 63 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import namedtuple
1+
from collections import defaultdict, namedtuple
22
from typing import Any, NamedTuple, Optional, Tuple
33

44
from pandas import Series
@@ -62,6 +62,12 @@ def __str__(self) -> str:
6262
return f"{self.left_arrow}{self.type_filter}{self.right_arrow}(target{self.label_filter})"
6363

6464

65+
class LabelPropertyMapping(NamedTuple):
    """Associates one node property with the label it is projected for.

    Instances are immutable tuples of ``(label, property_key, default_value)``.
    """

    # Node label the property belongs to.
    label: str
    # Key of the property to read from nodes carrying ``label``.
    property_key: str
    # Optional fallback value; ``None`` when no default was specified.
    default_value: Optional[Any] = None
69+
70+
6571
class GraphCypherRunner(IllegalAttrChecker):
6672
def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion) -> None:
6773
if server_version < ServerVersion(2, 4, 0):
@@ -121,6 +127,8 @@ def project(
121127
right_arrow="-" if inverse else "->",
122128
)
123129

130+
label_mappings = defaultdict(list)
131+
124132
if nodes:
125133
if len(nodes) == 1 or combine_labels_with == "AND":
126134
match_pattern = match_pattern._replace(label_filter=f":{':'.join(spec.source_label for spec in nodes)}")
@@ -147,14 +155,22 @@ def project(
147155
else:
148156
raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}")
149157

158+
for spec in nodes:
159+
if spec.properties:
160+
for prop in spec.properties:
161+
label_mappings[spec.source_label].append(
162+
LabelPropertyMapping(spec.source_label, prop.property_key, prop.default_value)
163+
)
164+
165+
rel_var = ""
150166
if rels:
151167
if len(rels) == 1:
152-
rel_var = ""
153168
data_config["relationshipType"] = rels[0].source_type
154169
else:
155170
rel_var = "rel"
156171
data_config["relationshipTypes"] = "type(rel)"
157172
data_config_is_static = False
173+
158174
match_pattern = match_pattern._replace(
159175
type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]"
160176
)
@@ -169,6 +185,24 @@ def project(
169185

170186
match_part = str(match_part)
171187

188+
case_part = []
189+
if label_mappings:
190+
with_rel = f", {rel_var}" if rel_var else ""
191+
case_part = [f"WITH source, target{with_rel}"]
192+
for kind in ["source", "target"]:
193+
case_part.append("CASE")
194+
195+
for label, mappings in label_mappings.items():
196+
mappings = ", ".join(f".{key.property_key}" for key in mappings)
197+
when_part = f"WHEN '{label}' in labels({kind}) THEN [{kind} {{{mappings}}}]"
198+
case_part.append(when_part)
199+
200+
case_part.append(f"END AS {kind}NodeProperties")
201+
202+
data_config["sourceNodeProperties"] = "sourceNodeProperties"
203+
data_config["targetNodeProperties"] = "targetNodeProperties"
204+
data_config_is_static = False
205+
172206
args = ["$graph_name", "source", "target"]
173207

174208
if data_config:
@@ -184,9 +218,7 @@ def project(
184218

185219
return_part = f"RETURN {self._namespace}({', '.join(args)})"
186220

187-
query = "\n".join(part for part in [match_part, return_part] if part)
188-
189-
print(query)
221+
query = "\n".join(part for part in [match_part, *case_part, return_part] if part)
190222

191223
result = self._query_runner.run_query_with_logging(
192224
query,
@@ -208,16 +240,39 @@ def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
208240
if isinstance(spec, dict):
209241
return [self._node_projection_spec(node, name) for name, node in spec.items()]
210242

211-
raise TypeError(f"Invalid node projection specification: {spec}")
243+
raise TypeError(f"Invalid node projections specification: {spec}")
212244

213245
def _node_projection_spec(self, spec: Any, name: Optional[str] = None) -> NodeProjection:
    """Turn one user-supplied node spec into a ``NodeProjection``.

    Accepted forms of ``spec``:

    * ``str`` -- a source label; the projection is named ``name`` when that
      is given (and truthy), otherwise it is named after the label itself.
    * ``dict`` -- maps property names to property specs; each entry is
      converted with ``_node_properties_spec(value, key)``.
    * ``list`` -- a sequence of property specs; each item is converted with
      ``_node_properties_spec(item)``.

    For the ``dict`` and ``list`` forms ``name`` is used as both the
    projection name and the source label.

    Raises:
        ValueError: if ``spec`` is not a string and no ``name`` was given.
        TypeError: if ``spec`` is none of the accepted forms.
    """
    # A bare string is just a label and needs no property handling.
    if isinstance(spec, str):
        return NodeProjection(name=name or spec, source_label=spec)

    # Every remaining form relies on ``name`` to supply the label.
    if name is None:
        raise ValueError(f"Node projections with properties must use the dict syntax: {spec}")

    if isinstance(spec, dict):
        properties = [self._node_properties_spec(value, key) for key, value in spec.items()]
    elif isinstance(spec, list):
        properties = [self._node_properties_spec(item) for item in spec]
    else:
        raise TypeError(f"Invalid node projection specification: {spec}")

    return NodeProjection(name=name, source_label=name, properties=properties)
218261

219-
def _node_properties_spec(self, properties: dict[str, Any]) -> list[NodeProperty]:
220-
raise TypeError(f"Invalid node projection specification: {properties}")
262+
def _node_properties_spec(self, spec: Any, name: Optional[str] = None) -> NodeProperty:
    """Turn one user-supplied property spec into a ``NodeProperty``.

    Accepted forms of ``spec``:

    * ``str`` -- the property key to read; the projected property is named
      ``name`` when that is given (and truthy), otherwise after the key.
    * ``True`` -- project the property called ``name`` under the same name.
    * ``dict`` -- extra keyword options for ``NodeProperty``. The property
      key defaults to ``name`` but may be overridden with a
      ``"property_key"`` entry, so an alias can be combined with other
      options.

    Raises:
        ValueError: if ``spec`` is not a string and no ``name`` was given.
        TypeError: if ``spec`` is none of the accepted forms.
    """
    if isinstance(spec, str):
        return NodeProperty(name=name or spec, property_key=spec)

    # Every remaining form relies on ``name`` to identify the property.
    if name is None:
        raise ValueError(f"Node properties spec must be used with the dict syntax: {spec}")

    if spec is True:
        return NodeProperty(name=name, property_key=name)

    if isinstance(spec, dict):
        # Work on a copy so the caller's dict is left untouched, and pull out
        # ``property_key`` explicitly: passing it both as a keyword and via
        # ``**spec`` would fail with a "multiple values for keyword argument"
        # TypeError instead of honouring the override.
        options = dict(spec)
        property_key = options.pop("property_key", name)
        return NodeProperty(name=name, property_key=property_key, **options)

    raise TypeError(f"Invalid node property specification: {spec}")
221276

222277
def _rel_projections_spec(self, spec: Any) -> list[RelationshipProjection]:
223278
if spec is None or spec is False:

graphdatascience/tests/unit/test_graph_cypher.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,34 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien
216216
"targetNodeLabels: labels(target), "
217217
"relationshipTypes: type(rel)})"
218218
)
219+
220+
221+
@pytest.mark.parametrize("server_version", [ServerVersion(2, 4, 0)])
def test_node_properties(runner: CollectingQueryRunner, gds: GraphDataScience) -> None:
    """Project three labels with properties and pin the exact generated query.

    The node spec exercises the list form (``L1``, ``L2``) and the dict form
    with both ``True`` and an empty options dict (``L3``).
    """
    node_spec = {
        "L1": ["prop1"],
        "L2": ["prop2", "prop3"],
        "L3": {"prop4": True, "prop5": {}},
    }

    G, _ = gds.graph.cypher.project("g", nodes=node_spec)

    assert G.name() == "g"
    assert runner.last_params() == {"graph_name": "g"}

    # NOTE(review): the expected text places the CASE expressions after
    # ``WITH source, target`` without comma separators -- confirm this is the
    # Cypher the runner is meant to emit.
    expected_lines = [
        "MATCH (source)-->(target)",
        "WHERE (source:L1 OR source:L2 OR source:L3) AND (target:L1 OR target:L2 OR target:L3)",
        "WITH source, target",
        "CASE",
        "WHEN 'L1' in labels(source) THEN [source {.prop1}]",
        "WHEN 'L2' in labels(source) THEN [source {.prop2, .prop3}]",
        "WHEN 'L3' in labels(source) THEN [source {.prop4, .prop5}]",
        "END AS sourceNodeProperties",
        "CASE",
        "WHEN 'L1' in labels(target) THEN [target {.prop1}]",
        "WHEN 'L2' in labels(target) THEN [target {.prop2, .prop3}]",
        "WHEN 'L3' in labels(target) THEN [target {.prop4, .prop5}]",
        "END AS targetNodeProperties",
        # The final line is one long string assembled from adjacent literals.
        "RETURN gds.graph.project($graph_name, source, target, {"
        "sourceNodeLabels: labels(source), "
        "targetNodeLabels: labels(target), "
        "sourceNodeProperties: sourceNodeProperties, "
        "targetNodeProperties: targetNodeProperties})",
    ]

    assert runner.last_query() == "\n".join(expected_lines)

0 commit comments

Comments
 (0)