Skip to content

Commit 46d58e2

Browse files
authored
Merge pull request #396 from knutwalker/cpv2
Use new Cypher projections API when running against 2.4.0+
2 parents d9fc752 + fc274fe commit 46d58e2

File tree

2 files changed

+154
-70
lines changed

2 files changed

+154
-70
lines changed

graphdatascience/query_runner/cypher_graph_constructor.py

Lines changed: 92 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import itertools
12
import os
23
import warnings
34
from dataclasses import dataclass
4-
from functools import reduce
55
from typing import Any, Dict, List, Optional, Set, Tuple
66
from uuid import uuid4
77

@@ -12,11 +12,14 @@
1212
from graphdatascience.server_version.server_version import ServerVersion
1313

1414

15-
class CypherAggregationApi:
15+
class CypherProjectionApi:
1616
RELATIONSHIP_TYPE = "relationshipType"
1717
SOURCE_NODE_LABEL = "sourceNodeLabels"
18+
TARGET_NODE_LABEL = "targetNodeLabels"
1819
SOURCE_NODE_PROPERTIES = "sourceNodeProperties"
20+
TARGET_NODE_PROPERTIES = "targetNodeProperties"
1921
REL_PROPERTIES = "properties"
22+
REL_PROPERTIES_NEW = "relationshipProperties"
2023

2124

2225
@dataclass
@@ -71,10 +74,14 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
7174
"edition graph construction (slower)"
7275
)
7376

74-
# Cypher aggregation supports concurrency since 2.3.0
77+
# New Cypher propjection supports concurrency since 2.3.0
7578
if self._server_version >= ServerVersion(2, 3, 0):
76-
self.CypherAggregationRunner(
77-
self._query_runner, self._graph_name, self._concurrency, self._undirected_relationship_types
79+
self.CypherProjectionRunner(
80+
self._query_runner,
81+
self._graph_name,
82+
self._concurrency,
83+
self._undirected_relationship_types,
84+
self._server_version,
7885
).run(node_dfs, relationship_dfs)
7986
else:
8087
assert not self._undirected_relationship_types, "This should have been raised earlier."
@@ -91,7 +98,9 @@ def graph_construct_error_multidf(element: str) -> str:
9198
node_df = node_dfs[0]
9299
rel_df = relationship_dfs[0]
93100

94-
self.CyperProjectionRunner(self._query_runner, self._graph_name, self._concurrency).run(node_df, rel_df)
101+
self.LegacyCypherProjectionRunner(self._query_runner, self._graph_name, self._concurrency).run(
102+
node_df, rel_df
103+
)
95104

96105
def _should_warn_about_arrow_missing(self) -> bool:
97106
try:
@@ -111,7 +120,7 @@ def _should_warn_about_arrow_missing(self) -> bool:
111120

112121
return should_warn
113122

114-
class CypherAggregationRunner:
123+
class CypherProjectionRunner:
115124
_BIT_COL_SUFFIX = "_is_present" + str(uuid4())
116125

117126
def __init__(
@@ -120,11 +129,13 @@ def __init__(
120129
graph_name: str,
121130
concurrency: int,
122131
undirected_relationship_types: Optional[List[str]],
132+
server_version: ServerVersion,
123133
):
124134
self._query_runner = query_runner
125135
self._concurrency = concurrency
126136
self._graph_name = graph_name
127137
self._undirected_relationship_types = undirected_relationship_types
138+
self._server_version = server_version
128139

129140
def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> None:
130141
graph_schema = self.schema(node_dfs, relationship_dfs)
@@ -137,8 +148,16 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
137148
f"but the columns {same_cols} exist in both dfs. Please rename the column in one df."
138149
)
139150

140-
aligned_node_dfs = self.adjust_node_dfs(node_dfs, graph_schema)
141-
aligned_rel_dfs = self.adjust_rel_dfs(relationship_dfs, graph_schema)
151+
is_cypher_projection_v2 = self._server_version >= ServerVersion(2, 4, 0)
152+
153+
rel_properties_key = (
154+
CypherProjectionApi.REL_PROPERTIES_NEW
155+
if is_cypher_projection_v2
156+
else CypherProjectionApi.REL_PROPERTIES
157+
)
158+
159+
aligned_node_dfs = self.adjust_node_dfs(node_dfs, graph_schema, rel_properties_key)
160+
aligned_rel_dfs = self.adjust_rel_dfs(relationship_dfs, graph_schema, rel_properties_key)
142161

143162
# concat instead of join as we want to first have all nodes and then the rels
144163
# this way we don't duplicate the node property data and its cheaper
@@ -151,12 +170,12 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
151170

152171
property_clauses: List[str] = [
153172
self.check_value_clause(combined_cols, prop_col)
154-
for prop_col in [CypherAggregationApi.SOURCE_NODE_PROPERTIES, CypherAggregationApi.REL_PROPERTIES]
173+
for prop_col in [CypherProjectionApi.SOURCE_NODE_PROPERTIES, rel_properties_key]
155174
]
156175

157176
source_node_labels_clause = (
158-
self.check_value_clause(combined_cols, CypherAggregationApi.SOURCE_NODE_LABEL)
159-
if CypherAggregationApi.SOURCE_NODE_LABEL in combined_cols
177+
self.check_value_clause(combined_cols, CypherProjectionApi.SOURCE_NODE_LABEL)
178+
if CypherProjectionApi.SOURCE_NODE_LABEL in combined_cols
160179
else ""
161180
)
162181
rel_type_clause = (
@@ -166,18 +185,25 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
166185
)
167186
target_id_clause = self.check_value_clause(combined_cols, "targetNodeId")
168187

169-
nodes_config = self.nodes_config(graph_schema.nodes_per_df)
170-
rels_config = self.rels_config(graph_schema.rels_per_df)
188+
nodes_config_part = self.nodes_config_part(graph_schema.nodes_per_df, is_cypher_projection_v2)
189+
rels_config_part = self.rels_config_part(graph_schema.rels_per_df, rel_properties_key)
190+
191+
if is_cypher_projection_v2:
192+
data_config = f"{{{', '.join(itertools.chain(nodes_config_part, rels_config_part))}}}"
193+
else:
194+
data_config = f"{{{', '.join(nodes_config_part)}}}, {{{', '.join(rels_config_part)}}}"
171195

172196
property_clauses_str = f"{os.linesep}" if len(property_clauses) > 0 else ""
173197
property_clauses_str += f"{os.linesep}".join(property_clauses)[:-2] # remove the final comma
174198

199+
tier = "" if is_cypher_projection_v2 else ".alpha"
200+
175201
query = (
176202
"UNWIND $data AS data"
177203
f" WITH data, {source_node_labels_clause}{rel_type_clause}{target_id_clause}{property_clauses_str}"
178-
" RETURN gds.alpha.graph.project("
204+
f" RETURN gds{tier}.graph.project("
179205
f"$graph_name, data[{combined_cols.index('sourceNodeId')}], targetNodeId, "
180-
f"{nodes_config}, {rels_config}, $configuration)"
206+
f"{data_config}, $configuration)"
181207
)
182208

183209
configuration = {
@@ -219,7 +245,9 @@ def schema(self, node_dfs: List[DataFrame], rel_dfs: List[DataFrame]) -> GraphCo
219245

220246
return GraphColumnSchema(node_schema, rel_schema)
221247

222-
def adjust_node_dfs(self, node_dfs: List[DataFrame], schema: GraphColumnSchema) -> List[DataFrame]:
248+
def adjust_node_dfs(
249+
self, node_dfs: List[DataFrame], schema: GraphColumnSchema, rel_properties_key: str
250+
) -> List[DataFrame]:
223251
adjusted_dfs = []
224252

225253
for i, df in enumerate(node_dfs):
@@ -229,31 +257,33 @@ def adjust_node_dfs(self, node_dfs: List[DataFrame], schema: GraphColumnSchema)
229257
f"targetNodeId{self._BIT_COL_SUFFIX}": False,
230258
}
231259

232-
if CypherAggregationApi.RELATIONSHIP_TYPE in schema.all_rels.all:
233-
node_dict[CypherAggregationApi.RELATIONSHIP_TYPE] = None
234-
node_dict[CypherAggregationApi.RELATIONSHIP_TYPE + self._BIT_COL_SUFFIX] = False
260+
if CypherProjectionApi.RELATIONSHIP_TYPE in schema.all_rels.all:
261+
node_dict[CypherProjectionApi.RELATIONSHIP_TYPE] = None
262+
node_dict[CypherProjectionApi.RELATIONSHIP_TYPE + self._BIT_COL_SUFFIX] = False
235263

236264
if "labels" in schema.nodes_per_df[i].all:
237-
node_dict[CypherAggregationApi.SOURCE_NODE_LABEL + self._BIT_COL_SUFFIX] = True
238-
node_dict[CypherAggregationApi.SOURCE_NODE_LABEL] = df["labels"]
265+
node_dict[CypherProjectionApi.SOURCE_NODE_LABEL + self._BIT_COL_SUFFIX] = True
266+
node_dict[CypherProjectionApi.SOURCE_NODE_LABEL] = df["labels"]
239267
elif "labels" in schema.all_nodes.all:
240-
node_dict[CypherAggregationApi.SOURCE_NODE_LABEL + self._BIT_COL_SUFFIX] = False
241-
node_dict[CypherAggregationApi.SOURCE_NODE_LABEL] = ""
268+
node_dict[CypherProjectionApi.SOURCE_NODE_LABEL + self._BIT_COL_SUFFIX] = False
269+
node_dict[CypherProjectionApi.SOURCE_NODE_LABEL] = ""
242270

243271
def collect_to_dict(row: Dict[str, Any]) -> Dict[str, Any]:
244272
return {column: row[column] for column in schema.nodes_per_df[i].properties}
245273

246274
node_dict_df = DataFrame(node_dict)
247-
node_dict_df[CypherAggregationApi.SOURCE_NODE_PROPERTIES] = df.apply(collect_to_dict, axis=1)
248-
node_dict_df[CypherAggregationApi.SOURCE_NODE_PROPERTIES + self._BIT_COL_SUFFIX] = True
249-
node_dict_df[CypherAggregationApi.REL_PROPERTIES] = None
250-
node_dict_df[CypherAggregationApi.REL_PROPERTIES + self._BIT_COL_SUFFIX] = False
275+
node_dict_df[CypherProjectionApi.SOURCE_NODE_PROPERTIES] = df.apply(collect_to_dict, axis=1)
276+
node_dict_df[CypherProjectionApi.SOURCE_NODE_PROPERTIES + self._BIT_COL_SUFFIX] = True
277+
node_dict_df[rel_properties_key] = None
278+
node_dict_df[rel_properties_key + self._BIT_COL_SUFFIX] = False
251279

252280
adjusted_dfs.append(node_dict_df)
253281

254282
return adjusted_dfs
255283

256-
def adjust_rel_dfs(self, rel_dfs: List[DataFrame], schema: GraphColumnSchema) -> List[DataFrame]:
284+
def adjust_rel_dfs(
285+
self, rel_dfs: List[DataFrame], schema: GraphColumnSchema, rel_properties_key: str
286+
) -> List[DataFrame]:
257287
adjusted_dfs = []
258288

259289
for i, df in enumerate(rel_dfs):
@@ -263,63 +293,69 @@ def adjust_rel_dfs(self, rel_dfs: List[DataFrame], schema: GraphColumnSchema) ->
263293
f"targetNodeId{self._BIT_COL_SUFFIX}": True,
264294
}
265295

266-
if CypherAggregationApi.RELATIONSHIP_TYPE in schema.rels_per_df[i].all:
267-
rel_dict[CypherAggregationApi.RELATIONSHIP_TYPE + self._BIT_COL_SUFFIX] = True
268-
rel_dict[CypherAggregationApi.RELATIONSHIP_TYPE] = df[CypherAggregationApi.RELATIONSHIP_TYPE]
269-
elif CypherAggregationApi.RELATIONSHIP_TYPE in schema.all_rels.all:
270-
rel_dict[CypherAggregationApi.RELATIONSHIP_TYPE + self._BIT_COL_SUFFIX] = False
271-
rel_dict[CypherAggregationApi.RELATIONSHIP_TYPE] = None
296+
if CypherProjectionApi.RELATIONSHIP_TYPE in schema.rels_per_df[i].all:
297+
rel_dict[CypherProjectionApi.RELATIONSHIP_TYPE + self._BIT_COL_SUFFIX] = True
298+
rel_dict[CypherProjectionApi.RELATIONSHIP_TYPE] = df[CypherProjectionApi.RELATIONSHIP_TYPE]
299+
elif CypherProjectionApi.RELATIONSHIP_TYPE in schema.all_rels.all:
300+
rel_dict[CypherProjectionApi.RELATIONSHIP_TYPE + self._BIT_COL_SUFFIX] = False
301+
rel_dict[CypherProjectionApi.RELATIONSHIP_TYPE] = None
272302

273303
if "labels" in schema.all_nodes.all:
274-
rel_dict[CypherAggregationApi.SOURCE_NODE_LABEL] = None
275-
rel_dict[CypherAggregationApi.SOURCE_NODE_LABEL + self._BIT_COL_SUFFIX] = False
304+
rel_dict[CypherProjectionApi.SOURCE_NODE_LABEL] = None
305+
rel_dict[CypherProjectionApi.SOURCE_NODE_LABEL + self._BIT_COL_SUFFIX] = False
276306

277307
def collect_to_dict(row: Dict[str, Any]) -> Dict[str, Any]:
278308
return {column: row[column] for column in schema.rels_per_df[i].properties}
279309

280310
rel_dict_df = DataFrame(rel_dict)
281-
rel_dict_df[CypherAggregationApi.REL_PROPERTIES] = df.apply(collect_to_dict, axis=1)
282-
rel_dict_df[CypherAggregationApi.REL_PROPERTIES + self._BIT_COL_SUFFIX] = True
283-
rel_dict_df[CypherAggregationApi.SOURCE_NODE_PROPERTIES] = None
284-
rel_dict_df[CypherAggregationApi.SOURCE_NODE_PROPERTIES + self._BIT_COL_SUFFIX] = False
311+
rel_dict_df[rel_properties_key] = df.apply(collect_to_dict, axis=1)
312+
rel_dict_df[rel_properties_key + self._BIT_COL_SUFFIX] = True
313+
rel_dict_df[CypherProjectionApi.SOURCE_NODE_PROPERTIES] = None
314+
rel_dict_df[CypherProjectionApi.SOURCE_NODE_PROPERTIES + self._BIT_COL_SUFFIX] = False
285315

286316
adjusted_dfs.append(rel_dict_df)
287317

288318
return adjusted_dfs
289319

290-
def nodes_config(self, node_cols: List[EntityColumnSchema]) -> str:
320+
def nodes_config_part(self, node_cols: List[EntityColumnSchema], is_cypher_projection_v2: bool) -> List[str]:
291321
# Cannot use a dictionary as we need to refer to the `data` variable in the cypher query.
292322
# Otherwise we would just pass a string such as `data[0]`
293323
nodes_config_fields: List[str] = []
294-
if reduce(lambda x, y: x | y.has_labels(), node_cols, False):
324+
if any(x.has_labels() for x in node_cols):
295325
nodes_config_fields.append(
296-
f"{CypherAggregationApi.SOURCE_NODE_LABEL}: {CypherAggregationApi.SOURCE_NODE_LABEL}"
326+
f"{CypherProjectionApi.SOURCE_NODE_LABEL}: {CypherProjectionApi.SOURCE_NODE_LABEL}"
297327
)
328+
if is_cypher_projection_v2:
329+
nodes_config_fields.append(
330+
f"{CypherProjectionApi.TARGET_NODE_LABEL}: NULL",
331+
)
298332

299333
# as we first list all nodes at the top of the df, we don't need to lookup properties for the target node
300-
if reduce(lambda x, y: x | y.has_properties(), node_cols, False):
334+
if any(x.has_properties() for x in node_cols):
301335
nodes_config_fields.append(
302-
f"{CypherAggregationApi.SOURCE_NODE_PROPERTIES}: {CypherAggregationApi.SOURCE_NODE_PROPERTIES}"
336+
f"{CypherProjectionApi.SOURCE_NODE_PROPERTIES}: {CypherProjectionApi.SOURCE_NODE_PROPERTIES}"
303337
)
338+
if is_cypher_projection_v2:
339+
nodes_config_fields.append(
340+
f"{CypherProjectionApi.TARGET_NODE_PROPERTIES}: NULL",
341+
)
304342

305-
return f"{{{', '.join(nodes_config_fields)}}}"
343+
return nodes_config_fields
306344

307-
def rels_config(self, rel_cols: List[EntityColumnSchema]) -> str:
345+
def rels_config_part(self, rel_cols: List[EntityColumnSchema], rel_properties_key: str) -> List[str]:
308346
rels_config_fields: List[str] = []
309347

310-
if reduce(lambda x, y: x | y.has_rel_type(), rel_cols, False):
348+
if any(x.has_rel_type() for x in rel_cols):
311349
rels_config_fields.append(
312-
f"{CypherAggregationApi.RELATIONSHIP_TYPE}: {CypherAggregationApi.RELATIONSHIP_TYPE}"
350+
f"{CypherProjectionApi.RELATIONSHIP_TYPE}: {CypherProjectionApi.RELATIONSHIP_TYPE}"
313351
)
314352

315-
if reduce(lambda x, y: x | y.has_properties(), rel_cols, False):
316-
rels_config_fields.append(
317-
f"{CypherAggregationApi.REL_PROPERTIES}: {CypherAggregationApi.REL_PROPERTIES}"
318-
)
353+
if any(x.has_properties() for x in rel_cols):
354+
rels_config_fields.append(f"{rel_properties_key}: {rel_properties_key}")
319355

320-
return f"{{{', '.join(rels_config_fields)}}}"
356+
return rels_config_fields
321357

322-
class CyperProjectionRunner:
358+
class LegacyCypherProjectionRunner:
323359
def __init__(self, query_runner: QueryRunner, graph_name: str, concurrency: int):
324360
self._query_runner = query_runner
325361
self._concurrency = concurrency

0 commit comments

Comments
 (0)