Skip to content

Commit f7735c0

Browse files
committed
Move non-static data config from params into query
1 parent cf63d86 commit f7735c0

File tree

2 files changed

+28
-25
lines changed

2 files changed

+28
-25
lines changed

graphdatascience/graph/graph_cypher_runner.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def project(
110110
query_params = {"graph_name": graph_name}
111111

112112
data_config = {}
113+
data_config_is_static = True
113114

114115
nodes = self._node_projections_spec(nodes)
115116
rels = self._rel_projections_spec(relationships)
@@ -142,6 +143,7 @@ def project(
142143

143144
data_config["sourceNodeLabels"] = "labels(source)"
144145
data_config["targetNodeLabels"] = "labels(target)"
146+
data_config_is_static = False
145147
else:
146148
raise ValueError(f"Invalid value for combine_labels_with: {combine_labels_with}")
147149

@@ -152,6 +154,7 @@ def project(
152154
else:
153155
rel_var = "rel"
154156
data_config["relationshipTypes"] = "type(rel)"
157+
data_config_is_static = False
155158
match_pattern = match_pattern._replace(
156159
type_filter=f"[{rel_var}:{'|'.join(spec.source_type for spec in rels)}]"
157160
)
@@ -169,8 +172,11 @@ def project(
169172
args = ["$graph_name", "source", "target"]
170173

171174
if data_config:
172-
query_params["data_config"] = data_config
173-
args += ["$data_config"]
175+
if data_config_is_static:
176+
query_params["data_config"] = data_config
177+
args += ["$data_config"]
178+
else:
179+
args += [self._render_map(data_config)]
174180

175181
if config:
176182
query_params["config"] = config
@@ -237,6 +243,9 @@ def _rel_projection_spec(self, spec: Any, name: Optional[str] = None) -> Relatio
237243
def _rel_properties_spec(self, properties: dict[str, Any]) -> list[RelationshipProperty]:
238244
raise TypeError(f"Invalid relationship projection specification: {properties}")
239245

246+
def _render_map(self, mapping: dict[str, Any]) -> str:
247+
return "{" + ", ".join(f"{key}: {value}" for key, value in mapping.items()) + "}"
248+
240249
#
241250
# def estimate(self, *, nodes: Any, relationships: Any, **config: Any) -> "Series[Any]":
242251
# pass

graphdatascience/tests/unit/test_graph_cypher.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,14 @@ def test_multiple_node_labels_or(runner: CollectingQueryRunner, gds: GraphDataSc
136136
G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR")
137137

138138
assert G.name() == "g"
139-
assert runner.last_params() == dict(
140-
graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"}
141-
)
139+
assert runner.last_params() == dict(graph_name="g")
142140

143-
assert (
144-
runner.last_query()
145-
== """MATCH (source)-->(target)
141+
assert runner.last_query() == (
142+
"""MATCH (source)-->(target)
146143
WHERE (source:A OR source:B) AND (target:A OR target:B)
147-
RETURN gds.graph.project($graph_name, source, target, $data_config)"""
144+
RETURN gds.graph.project($graph_name, source, target, {"""
145+
"sourceNodeLabels: labels(source), "
146+
"targetNodeLabels: labels(target)})"
148147
)
149148

150149

@@ -153,17 +152,16 @@ def test_disconnected_nodes_multiple_node_labels_or(runner: CollectingQueryRunne
153152
G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], combine_labels_with="OR", allow_disconnected_nodes=True)
154153

155154
assert G.name() == "g"
156-
assert runner.last_params() == dict(
157-
graph_name="g", data_config={"sourceNodeLabels": "labels(source)", "targetNodeLabels": "labels(target)"}
158-
)
155+
assert runner.last_params() == dict(graph_name="g")
159156

160-
assert (
161-
runner.last_query()
162-
== """MATCH (source)
157+
assert runner.last_query() == (
158+
"""MATCH (source)
163159
WHERE source:A OR source:B
164160
OPTIONAL MATCH (source)-->(target)
165161
WHERE target:A OR target:B
166-
RETURN gds.graph.project($graph_name, source, target, $data_config)"""
162+
RETURN gds.graph.project($graph_name, source, target, {"""
163+
"sourceNodeLabels: labels(source), "
164+
"targetNodeLabels: labels(target)})"
167165
)
168166

169167

@@ -207,18 +205,14 @@ def test_multiple_multi_graph(runner: CollectingQueryRunner, gds: GraphDataScien
207205
G, _ = gds.graph.cypher.project("g", nodes=["A", "B"], relationships=["REL1", "REL2"])
208206

209207
assert G.name() == "g"
210-
assert runner.last_params() == dict(
211-
graph_name="g",
212-
data_config={
213-
"sourceNodeLabels": "labels(source)",
214-
"targetNodeLabels": "labels(target)",
215-
"relationshipTypes": "type(rel)",
216-
},
217-
)
208+
assert runner.last_params() == dict(graph_name="g")
218209

219210
assert (
220211
runner.last_query()
221212
== """MATCH (source)-[rel:REL1|REL2]->(target)
222213
WHERE (source:A OR source:B) AND (target:A OR target:B)
223-
RETURN gds.graph.project($graph_name, source, target, $data_config)"""
214+
RETURN gds.graph.project($graph_name, source, target, {"""
215+
"sourceNodeLabels: labels(source), "
216+
"targetNodeLabels: labels(target), "
217+
"relationshipTypes: type(rel)})"
224218
)

0 commit comments

Comments
 (0)