1+ import itertools
12import os
23import warnings
34from dataclasses import dataclass
4- from functools import reduce
55from typing import Any , Dict , List , Optional , Set , Tuple
66from uuid import uuid4
77
1212from graphdatascience .server_version .server_version import ServerVersion
1313
1414
15- class CypherAggregationApi :
15+ class CypherProjectionApi :
1616 RELATIONSHIP_TYPE = "relationshipType"
1717 SOURCE_NODE_LABEL = "sourceNodeLabels"
18+ TARGET_NODE_LABEL = "targetNodeLabels"
1819 SOURCE_NODE_PROPERTIES = "sourceNodeProperties"
20+ TARGET_NODE_PROPERTIES = "targetNodeProperties"
1921 REL_PROPERTIES = "properties"
22+ REL_PROPERTIES_NEW = "relationshipProperties"
2023
2124
2225@dataclass
@@ -71,10 +74,14 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
7174 "edition graph construction (slower)"
7275 )
7376
74- # Cypher aggregation supports concurrency since 2.3.0
77+ # New Cypher projection supports concurrency since 2.3.0
7578 if self ._server_version >= ServerVersion (2 , 3 , 0 ):
76- self .CypherAggregationRunner (
77- self ._query_runner , self ._graph_name , self ._concurrency , self ._undirected_relationship_types
79+ self .CypherProjectionRunner (
80+ self ._query_runner ,
81+ self ._graph_name ,
82+ self ._concurrency ,
83+ self ._undirected_relationship_types ,
84+ self ._server_version ,
7885 ).run (node_dfs , relationship_dfs )
7986 else :
8087 assert not self ._undirected_relationship_types , "This should have been raised earlier."
@@ -91,7 +98,9 @@ def graph_construct_error_multidf(element: str) -> str:
9198 node_df = node_dfs [0 ]
9299 rel_df = relationship_dfs [0 ]
93100
94- self .CyperProjectionRunner (self ._query_runner , self ._graph_name , self ._concurrency ).run (node_df , rel_df )
101+ self .LegacyCypherProjectionRunner (self ._query_runner , self ._graph_name , self ._concurrency ).run (
102+ node_df , rel_df
103+ )
95104
96105 def _should_warn_about_arrow_missing (self ) -> bool :
97106 try :
@@ -111,7 +120,7 @@ def _should_warn_about_arrow_missing(self) -> bool:
111120
112121 return should_warn
113122
114- class CypherAggregationRunner :
123+ class CypherProjectionRunner :
115124 _BIT_COL_SUFFIX = "_is_present" + str (uuid4 ())
116125
117126 def __init__ (
@@ -120,11 +129,13 @@ def __init__(
120129 graph_name : str ,
121130 concurrency : int ,
122131 undirected_relationship_types : Optional [List [str ]],
132+ server_version : ServerVersion ,
123133 ):
124134 self ._query_runner = query_runner
125135 self ._concurrency = concurrency
126136 self ._graph_name = graph_name
127137 self ._undirected_relationship_types = undirected_relationship_types
138+ self ._server_version = server_version
128139
129140 def run (self , node_dfs : List [DataFrame ], relationship_dfs : List [DataFrame ]) -> None :
130141 graph_schema = self .schema (node_dfs , relationship_dfs )
@@ -137,8 +148,16 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
137148 f"but the columns { same_cols } exist in both dfs. Please rename the column in one df."
138149 )
139150
140- aligned_node_dfs = self .adjust_node_dfs (node_dfs , graph_schema )
141- aligned_rel_dfs = self .adjust_rel_dfs (relationship_dfs , graph_schema )
151+ is_cypher_projection_v2 = self ._server_version >= ServerVersion (2 , 4 , 0 )
152+
153+ rel_properties_key = (
154+ CypherProjectionApi .REL_PROPERTIES_NEW
155+ if is_cypher_projection_v2
156+ else CypherProjectionApi .REL_PROPERTIES
157+ )
158+
159+ aligned_node_dfs = self .adjust_node_dfs (node_dfs , graph_schema , rel_properties_key )
160+ aligned_rel_dfs = self .adjust_rel_dfs (relationship_dfs , graph_schema , rel_properties_key )
142161
143162 # concat instead of join as we want to first have all nodes and then the rels
144163 # this way we don't duplicate the node property data and it's cheaper
@@ -151,12 +170,12 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
151170
152171 property_clauses : List [str ] = [
153172 self .check_value_clause (combined_cols , prop_col )
154- for prop_col in [CypherAggregationApi .SOURCE_NODE_PROPERTIES , CypherAggregationApi . REL_PROPERTIES ]
173+ for prop_col in [CypherProjectionApi .SOURCE_NODE_PROPERTIES , rel_properties_key ]
155174 ]
156175
157176 source_node_labels_clause = (
158- self .check_value_clause (combined_cols , CypherAggregationApi .SOURCE_NODE_LABEL )
159- if CypherAggregationApi .SOURCE_NODE_LABEL in combined_cols
177+ self .check_value_clause (combined_cols , CypherProjectionApi .SOURCE_NODE_LABEL )
178+ if CypherProjectionApi .SOURCE_NODE_LABEL in combined_cols
160179 else ""
161180 )
162181 rel_type_clause = (
@@ -166,18 +185,25 @@ def run(self, node_dfs: List[DataFrame], relationship_dfs: List[DataFrame]) -> N
166185 )
167186 target_id_clause = self .check_value_clause (combined_cols , "targetNodeId" )
168187
169- nodes_config = self .nodes_config (graph_schema .nodes_per_df )
170- rels_config = self .rels_config (graph_schema .rels_per_df )
188+ nodes_config_part = self .nodes_config_part (graph_schema .nodes_per_df , is_cypher_projection_v2 )
189+ rels_config_part = self .rels_config_part (graph_schema .rels_per_df , rel_properties_key )
190+
191+ if is_cypher_projection_v2 :
192+ data_config = f"{{{ ', ' .join (itertools .chain (nodes_config_part , rels_config_part ))} }}"
193+ else :
194+ data_config = f"{{{ ', ' .join (nodes_config_part )} }}, {{{ ', ' .join (rels_config_part )} }}"
171195
172196 property_clauses_str = f"{ os .linesep } " if len (property_clauses ) > 0 else ""
173197 property_clauses_str += f"{ os .linesep } " .join (property_clauses )[:- 2 ] # remove the final comma
174198
199+ tier = "" if is_cypher_projection_v2 else ".alpha"
200+
175201 query = (
176202 "UNWIND $data AS data"
177203 f" WITH data, { source_node_labels_clause } { rel_type_clause } { target_id_clause } { property_clauses_str } "
178- " RETURN gds.alpha .graph.project("
204+ f " RETURN gds{ tier } .graph.project("
179205 f"$graph_name, data[{ combined_cols .index ('sourceNodeId' )} ], targetNodeId, "
180- f"{ nodes_config } , { rels_config } , $configuration)"
206+ f"{ data_config } , $configuration)"
181207 )
182208
183209 configuration = {
@@ -219,7 +245,9 @@ def schema(self, node_dfs: List[DataFrame], rel_dfs: List[DataFrame]) -> GraphCo
219245
220246 return GraphColumnSchema (node_schema , rel_schema )
221247
222- def adjust_node_dfs (self , node_dfs : List [DataFrame ], schema : GraphColumnSchema ) -> List [DataFrame ]:
248+ def adjust_node_dfs (
249+ self , node_dfs : List [DataFrame ], schema : GraphColumnSchema , rel_properties_key : str
250+ ) -> List [DataFrame ]:
223251 adjusted_dfs = []
224252
225253 for i , df in enumerate (node_dfs ):
@@ -229,31 +257,33 @@ def adjust_node_dfs(self, node_dfs: List[DataFrame], schema: GraphColumnSchema)
229257 f"targetNodeId{ self ._BIT_COL_SUFFIX } " : False ,
230258 }
231259
232- if CypherAggregationApi .RELATIONSHIP_TYPE in schema .all_rels .all :
233- node_dict [CypherAggregationApi .RELATIONSHIP_TYPE ] = None
234- node_dict [CypherAggregationApi .RELATIONSHIP_TYPE + self ._BIT_COL_SUFFIX ] = False
260+ if CypherProjectionApi .RELATIONSHIP_TYPE in schema .all_rels .all :
261+ node_dict [CypherProjectionApi .RELATIONSHIP_TYPE ] = None
262+ node_dict [CypherProjectionApi .RELATIONSHIP_TYPE + self ._BIT_COL_SUFFIX ] = False
235263
236264 if "labels" in schema .nodes_per_df [i ].all :
237- node_dict [CypherAggregationApi .SOURCE_NODE_LABEL + self ._BIT_COL_SUFFIX ] = True
238- node_dict [CypherAggregationApi .SOURCE_NODE_LABEL ] = df ["labels" ]
265+ node_dict [CypherProjectionApi .SOURCE_NODE_LABEL + self ._BIT_COL_SUFFIX ] = True
266+ node_dict [CypherProjectionApi .SOURCE_NODE_LABEL ] = df ["labels" ]
239267 elif "labels" in schema .all_nodes .all :
240- node_dict [CypherAggregationApi .SOURCE_NODE_LABEL + self ._BIT_COL_SUFFIX ] = False
241- node_dict [CypherAggregationApi .SOURCE_NODE_LABEL ] = ""
268+ node_dict [CypherProjectionApi .SOURCE_NODE_LABEL + self ._BIT_COL_SUFFIX ] = False
269+ node_dict [CypherProjectionApi .SOURCE_NODE_LABEL ] = ""
242270
243271 def collect_to_dict (row : Dict [str , Any ]) -> Dict [str , Any ]:
244272 return {column : row [column ] for column in schema .nodes_per_df [i ].properties }
245273
246274 node_dict_df = DataFrame (node_dict )
247- node_dict_df [CypherAggregationApi .SOURCE_NODE_PROPERTIES ] = df .apply (collect_to_dict , axis = 1 )
248- node_dict_df [CypherAggregationApi .SOURCE_NODE_PROPERTIES + self ._BIT_COL_SUFFIX ] = True
249- node_dict_df [CypherAggregationApi . REL_PROPERTIES ] = None
250- node_dict_df [CypherAggregationApi . REL_PROPERTIES + self ._BIT_COL_SUFFIX ] = False
275+ node_dict_df [CypherProjectionApi .SOURCE_NODE_PROPERTIES ] = df .apply (collect_to_dict , axis = 1 )
276+ node_dict_df [CypherProjectionApi .SOURCE_NODE_PROPERTIES + self ._BIT_COL_SUFFIX ] = True
277+ node_dict_df [rel_properties_key ] = None
278+ node_dict_df [rel_properties_key + self ._BIT_COL_SUFFIX ] = False
251279
252280 adjusted_dfs .append (node_dict_df )
253281
254282 return adjusted_dfs
255283
256- def adjust_rel_dfs (self , rel_dfs : List [DataFrame ], schema : GraphColumnSchema ) -> List [DataFrame ]:
284+ def adjust_rel_dfs (
285+ self , rel_dfs : List [DataFrame ], schema : GraphColumnSchema , rel_properties_key : str
286+ ) -> List [DataFrame ]:
257287 adjusted_dfs = []
258288
259289 for i , df in enumerate (rel_dfs ):
@@ -263,63 +293,69 @@ def adjust_rel_dfs(self, rel_dfs: List[DataFrame], schema: GraphColumnSchema) ->
263293 f"targetNodeId{ self ._BIT_COL_SUFFIX } " : True ,
264294 }
265295
266- if CypherAggregationApi .RELATIONSHIP_TYPE in schema .rels_per_df [i ].all :
267- rel_dict [CypherAggregationApi .RELATIONSHIP_TYPE + self ._BIT_COL_SUFFIX ] = True
268- rel_dict [CypherAggregationApi .RELATIONSHIP_TYPE ] = df [CypherAggregationApi .RELATIONSHIP_TYPE ]
269- elif CypherAggregationApi .RELATIONSHIP_TYPE in schema .all_rels .all :
270- rel_dict [CypherAggregationApi .RELATIONSHIP_TYPE + self ._BIT_COL_SUFFIX ] = False
271- rel_dict [CypherAggregationApi .RELATIONSHIP_TYPE ] = None
296+ if CypherProjectionApi .RELATIONSHIP_TYPE in schema .rels_per_df [i ].all :
297+ rel_dict [CypherProjectionApi .RELATIONSHIP_TYPE + self ._BIT_COL_SUFFIX ] = True
298+ rel_dict [CypherProjectionApi .RELATIONSHIP_TYPE ] = df [CypherProjectionApi .RELATIONSHIP_TYPE ]
299+ elif CypherProjectionApi .RELATIONSHIP_TYPE in schema .all_rels .all :
300+ rel_dict [CypherProjectionApi .RELATIONSHIP_TYPE + self ._BIT_COL_SUFFIX ] = False
301+ rel_dict [CypherProjectionApi .RELATIONSHIP_TYPE ] = None
272302
273303 if "labels" in schema .all_nodes .all :
274- rel_dict [CypherAggregationApi .SOURCE_NODE_LABEL ] = None
275- rel_dict [CypherAggregationApi .SOURCE_NODE_LABEL + self ._BIT_COL_SUFFIX ] = False
304+ rel_dict [CypherProjectionApi .SOURCE_NODE_LABEL ] = None
305+ rel_dict [CypherProjectionApi .SOURCE_NODE_LABEL + self ._BIT_COL_SUFFIX ] = False
276306
277307 def collect_to_dict (row : Dict [str , Any ]) -> Dict [str , Any ]:
278308 return {column : row [column ] for column in schema .rels_per_df [i ].properties }
279309
280310 rel_dict_df = DataFrame (rel_dict )
281- rel_dict_df [CypherAggregationApi . REL_PROPERTIES ] = df .apply (collect_to_dict , axis = 1 )
282- rel_dict_df [CypherAggregationApi . REL_PROPERTIES + self ._BIT_COL_SUFFIX ] = True
283- rel_dict_df [CypherAggregationApi .SOURCE_NODE_PROPERTIES ] = None
284- rel_dict_df [CypherAggregationApi .SOURCE_NODE_PROPERTIES + self ._BIT_COL_SUFFIX ] = False
311+ rel_dict_df [rel_properties_key ] = df .apply (collect_to_dict , axis = 1 )
312+ rel_dict_df [rel_properties_key + self ._BIT_COL_SUFFIX ] = True
313+ rel_dict_df [CypherProjectionApi .SOURCE_NODE_PROPERTIES ] = None
314+ rel_dict_df [CypherProjectionApi .SOURCE_NODE_PROPERTIES + self ._BIT_COL_SUFFIX ] = False
285315
286316 adjusted_dfs .append (rel_dict_df )
287317
288318 return adjusted_dfs
289319
290- def nodes_config (self , node_cols : List [EntityColumnSchema ]) -> str :
320+ def nodes_config_part (self , node_cols : List [EntityColumnSchema ], is_cypher_projection_v2 : bool ) -> List [ str ] :
291321 # Cannot use a dictionary as we need to refer to the `data` variable in the cypher query.
292322 # Otherwise we would just pass a string such as `data[0]`
293323 nodes_config_fields : List [str ] = []
294- if reduce ( lambda x , y : x | y .has_labels (), node_cols , False ):
324+ if any ( x .has_labels () for x in node_cols ):
295325 nodes_config_fields .append (
296- f"{ CypherAggregationApi .SOURCE_NODE_LABEL } : { CypherAggregationApi .SOURCE_NODE_LABEL } "
326+ f"{ CypherProjectionApi .SOURCE_NODE_LABEL } : { CypherProjectionApi .SOURCE_NODE_LABEL } "
297327 )
328+ if is_cypher_projection_v2 :
329+ nodes_config_fields .append (
330+ f"{ CypherProjectionApi .TARGET_NODE_LABEL } : NULL" ,
331+ )
298332
299333 # as we first list all nodes at the top of the df, we don't need to lookup properties for the target node
300- if reduce ( lambda x , y : x | y .has_properties (), node_cols , False ):
334+ if any ( x .has_properties () for x in node_cols ):
301335 nodes_config_fields .append (
302- f"{ CypherAggregationApi .SOURCE_NODE_PROPERTIES } : { CypherAggregationApi .SOURCE_NODE_PROPERTIES } "
336+ f"{ CypherProjectionApi .SOURCE_NODE_PROPERTIES } : { CypherProjectionApi .SOURCE_NODE_PROPERTIES } "
303337 )
338+ if is_cypher_projection_v2 :
339+ nodes_config_fields .append (
340+ f"{ CypherProjectionApi .TARGET_NODE_PROPERTIES } : NULL" ,
341+ )
304342
305- return f"{{ { ', ' . join ( nodes_config_fields ) } }}"
343+ return nodes_config_fields
306344
307- def rels_config (self , rel_cols : List [EntityColumnSchema ]) -> str :
345+ def rels_config_part (self , rel_cols : List [EntityColumnSchema ], rel_properties_key : str ) -> List [ str ] :
308346 rels_config_fields : List [str ] = []
309347
310- if reduce ( lambda x , y : x | y .has_rel_type (), rel_cols , False ):
348+ if any ( x .has_rel_type () for x in rel_cols ):
311349 rels_config_fields .append (
312- f"{ CypherAggregationApi .RELATIONSHIP_TYPE } : { CypherAggregationApi .RELATIONSHIP_TYPE } "
350+ f"{ CypherProjectionApi .RELATIONSHIP_TYPE } : { CypherProjectionApi .RELATIONSHIP_TYPE } "
313351 )
314352
315- if reduce (lambda x , y : x | y .has_properties (), rel_cols , False ):
316- rels_config_fields .append (
317- f"{ CypherAggregationApi .REL_PROPERTIES } : { CypherAggregationApi .REL_PROPERTIES } "
318- )
353+ if any (x .has_properties () for x in rel_cols ):
354+ rels_config_fields .append (f"{ rel_properties_key } : { rel_properties_key } " )
319355
320- return f"{{ { ', ' . join ( rels_config_fields ) } }}"
356+ return rels_config_fields
321357
322- class CyperProjectionRunner :
358+ class LegacyCypherProjectionRunner :
323359 def __init__ (self , query_runner : QueryRunner , graph_name : str , concurrency : int ):
324360 self ._query_runner = query_runner
325361 self ._concurrency = concurrency
0 commit comments