1- from collections import namedtuple
1+ from collections import defaultdict , namedtuple
22from typing import Any , NamedTuple , Optional , Tuple
33
44from pandas import Series
@@ -72,6 +72,12 @@ def __str__(self) -> str:
7272 return f"{ self .left_arrow } { self .type_filter } { self .right_arrow } (target{ self .label_filter } )"
7373
7474
75+ class LabelPropertyMapping (NamedTuple ):
76+ label : str
77+ property_key : str
78+ default_value : Optional [Any ] = None
79+
80+
7581class GraphCypherRunner (IllegalAttrChecker ):
7682 def __init__ (self , query_runner : QueryRunner , namespace : str , server_version : ServerVersion ) -> None :
7783 if server_version < ServerVersion (2 , 4 , 0 ):
@@ -131,6 +137,8 @@ def project(
131137 right_arrow = "-" if inverse else "->" ,
132138 )
133139
140+ label_mappings = defaultdict (list )
141+
134142 if nodes :
135143 if len (nodes ) == 1 or combine_labels_with == "AND" :
136144 match_pattern = match_pattern ._replace (label_filter = f":{ ':' .join (spec .source_label for spec in nodes )} " )
@@ -157,14 +165,22 @@ def project(
157165 else :
158166 raise ValueError (f"Invalid value for combine_labels_with: { combine_labels_with } " )
159167
168+ for spec in nodes :
169+ if spec .properties :
170+ for prop in spec .properties :
171+ label_mappings [spec .source_label ].append (
172+ LabelPropertyMapping (spec .source_label , prop .property_key , prop .default_value )
173+ )
174+
175+ rel_var = ""
160176 if rels :
161177 if len (rels ) == 1 :
162- rel_var = ""
163178 data_config ["relationshipType" ] = rels [0 ].source_type
164179 else :
165180 rel_var = "rel"
166181 data_config ["relationshipTypes" ] = "type(rel)"
167182 data_config_is_static = False
183+
168184 match_pattern = match_pattern ._replace (
169185 type_filter = f"[{ rel_var } :{ '|' .join (spec .source_type for spec in rels )} ]"
170186 )
@@ -179,6 +195,24 @@ def project(
179195
180196 match_part = str (match_part )
181197
198+ case_part = []
199+ if label_mappings :
200+ with_rel = f", { rel_var } " if rel_var else ""
201+ case_part = [f"WITH source, target{ with_rel } " ]
202+ for kind in ["source" , "target" ]:
203+ case_part .append ("CASE" )
204+
205+ for label , mappings in label_mappings .items ():
206+ mappings = ", " .join (f".{ key .property_key } " for key in mappings )
207+ when_part = f"WHEN '{ label } ' in labels({ kind } ) THEN [{ kind } {{{ mappings } }}]"
208+ case_part .append (when_part )
209+
210+ case_part .append (f"END AS { kind } NodeProperties" )
211+
212+ data_config ["sourceNodeProperties" ] = "sourceNodeProperties"
213+ data_config ["targetNodeProperties" ] = "targetNodeProperties"
214+ data_config_is_static = False
215+
182216 args = ["$graph_name" , "source" , "target" ]
183217
184218 if data_config :
@@ -194,9 +228,7 @@ def project(
194228
195229 return_part = f"RETURN { self ._namespace } ({ ', ' .join (args )} )"
196230
197- query = "\n " .join (part for part in [match_part , return_part ] if part )
198-
199- print (query )
231+ query = "\n " .join (part for part in [match_part , * case_part , return_part ] if part )
200232
201233 result = self ._query_runner .run_query_with_logging (
202234 query ,
@@ -218,16 +250,39 @@ def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
218250 if isinstance (spec , dict ):
219251 return [self ._node_projection_spec (node , name ) for name , node in spec .items ()]
220252
221- raise TypeError (f"Invalid node projection specification: { spec } " )
253+ raise TypeError (f"Invalid node projections specification: { spec } " )
222254
223255 def _node_projection_spec (self , spec : Any , name : Optional [str ] = None ) -> NodeProjection :
224256 if isinstance (spec , str ):
225257 return NodeProjection (name = name or spec , source_label = spec )
226258
259+ if name is None :
260+ raise ValueError (f"Node projections with properties must use the dict syntax: { spec } " )
261+
262+ if isinstance (spec , dict ):
263+ properties = [self ._node_properties_spec (prop , name ) for name , prop in spec .items ()]
264+ return NodeProjection (name = name , source_label = name , properties = properties )
265+
266+ if isinstance (spec , list ):
267+ properties = [self ._node_properties_spec (prop ) for prop in spec ]
268+ return NodeProjection (name = name , source_label = name , properties = properties )
269+
227270 raise TypeError (f"Invalid node projection specification: { spec } " )
228271
229- def _node_properties_spec (self , properties : dict [str , Any ]) -> list [NodeProperty ]:
230- raise TypeError (f"Invalid node projection specification: { properties } " )
272+ def _node_properties_spec (self , spec : Any , name : Optional [str ] = None ) -> NodeProperty :
273+ if isinstance (spec , str ):
274+ return NodeProperty (name = name or spec , property_key = spec )
275+
276+ if name is None :
277+ raise ValueError (f"Node properties spec must be used with the dict syntax: { spec } " )
278+
279+ if spec is True :
280+ return NodeProperty (name = name , property_key = name )
281+
282+ if isinstance (spec , dict ):
283+ return NodeProperty (name = name , property_key = name , ** spec )
284+
285+ raise TypeError (f"Invalid node property specification: { spec } " )
231286
232287 def _rel_projections_spec (self , spec : Any ) -> list [RelationshipProjection ]:
233288 if spec is None or spec is False :
0 commit comments