1- from collections import namedtuple
1+ from collections import defaultdict , namedtuple
22from typing import Any , NamedTuple , Optional , Tuple
33
44from pandas import Series
@@ -62,6 +62,12 @@ def __str__(self) -> str:
6262 return f"{ self .left_arrow } { self .type_filter } { self .right_arrow } (target{ self .label_filter } )"
6363
6464
65+ class LabelPropertyMapping (NamedTuple ):
66+ label : str
67+ property_key : str
68+ default_value : Optional [Any ] = None
69+
70+
6571class GraphCypherRunner (IllegalAttrChecker ):
6672 def __init__ (self , query_runner : QueryRunner , namespace : str , server_version : ServerVersion ) -> None :
6773 if server_version < ServerVersion (2 , 4 , 0 ):
@@ -121,6 +127,8 @@ def project(
121127 right_arrow = "-" if inverse else "->" ,
122128 )
123129
130+ label_mappings = defaultdict (list )
131+
124132 if nodes :
125133 if len (nodes ) == 1 or combine_labels_with == "AND" :
126134 match_pattern = match_pattern ._replace (label_filter = f":{ ':' .join (spec .source_label for spec in nodes )} " )
@@ -147,14 +155,22 @@ def project(
147155 else :
148156 raise ValueError (f"Invalid value for combine_labels_with: { combine_labels_with } " )
149157
158+ for spec in nodes :
159+ if spec .properties :
160+ for prop in spec .properties :
161+ label_mappings [spec .source_label ].append (
162+ LabelPropertyMapping (spec .source_label , prop .property_key , prop .default_value )
163+ )
164+
165+ rel_var = ""
150166 if rels :
151167 if len (rels ) == 1 :
152- rel_var = ""
153168 data_config ["relationshipType" ] = rels [0 ].source_type
154169 else :
155170 rel_var = "rel"
156171 data_config ["relationshipTypes" ] = "type(rel)"
157172 data_config_is_static = False
173+
158174 match_pattern = match_pattern ._replace (
159175 type_filter = f"[{ rel_var } :{ '|' .join (spec .source_type for spec in rels )} ]"
160176 )
@@ -169,6 +185,24 @@ def project(
169185
170186 match_part = str (match_part )
171187
188+ case_part = []
189+ if label_mappings :
190+ with_rel = f", { rel_var } " if rel_var else ""
191+ case_part = [f"WITH source, target{ with_rel } " ]
192+ for kind in ["source" , "target" ]:
193+ case_part .append ("CASE" )
194+
195+ for label , mappings in label_mappings .items ():
196+ mappings = ", " .join (f".{ key .property_key } " for key in mappings )
197+ when_part = f"WHEN '{ label } ' in labels({ kind } ) THEN [{ kind } {{{ mappings } }}]"
198+ case_part .append (when_part )
199+
200+ case_part .append (f"END AS { kind } NodeProperties" )
201+
202+ data_config ["sourceNodeProperties" ] = "sourceNodeProperties"
203+ data_config ["targetNodeProperties" ] = "targetNodeProperties"
204+ data_config_is_static = False
205+
172206 args = ["$graph_name" , "source" , "target" ]
173207
174208 if data_config :
@@ -184,9 +218,7 @@ def project(
184218
185219 return_part = f"RETURN { self ._namespace } ({ ', ' .join (args )} )"
186220
187- query = "\n " .join (part for part in [match_part , return_part ] if part )
188-
189- print (query )
221+ query = "\n " .join (part for part in [match_part , * case_part , return_part ] if part )
190222
191223 result = self ._query_runner .run_query_with_logging (
192224 query ,
@@ -208,16 +240,39 @@ def _node_projections_spec(self, spec: Any) -> list[NodeProjection]:
208240 if isinstance (spec , dict ):
209241 return [self ._node_projection_spec (node , name ) for name , node in spec .items ()]
210242
211- raise TypeError (f"Invalid node projection specification: { spec } " )
243+ raise TypeError (f"Invalid node projections specification: { spec } " )
212244
213245 def _node_projection_spec (self , spec : Any , name : Optional [str ] = None ) -> NodeProjection :
214246 if isinstance (spec , str ):
215247 return NodeProjection (name = name or spec , source_label = spec )
216248
249+ if name is None :
250+ raise ValueError (f"Node projections with properties must use the dict syntax: { spec } " )
251+
252+ if isinstance (spec , dict ):
253+ properties = [self ._node_properties_spec (prop , name ) for name , prop in spec .items ()]
254+ return NodeProjection (name = name , source_label = name , properties = properties )
255+
256+ if isinstance (spec , list ):
257+ properties = [self ._node_properties_spec (prop ) for prop in spec ]
258+ return NodeProjection (name = name , source_label = name , properties = properties )
259+
217260 raise TypeError (f"Invalid node projection specification: { spec } " )
218261
219- def _node_properties_spec (self , properties : dict [str , Any ]) -> list [NodeProperty ]:
220- raise TypeError (f"Invalid node projection specification: { properties } " )
262+ def _node_properties_spec (self , spec : Any , name : Optional [str ] = None ) -> NodeProperty :
263+ if isinstance (spec , str ):
264+ return NodeProperty (name = name or spec , property_key = spec )
265+
266+ if name is None :
267+ raise ValueError (f"Node properties spec must be used with the dict syntax: { spec } " )
268+
269+ if spec is True :
270+ return NodeProperty (name = name , property_key = name )
271+
272+ if isinstance (spec , dict ):
273+ return NodeProperty (name = name , property_key = name , ** spec )
274+
275+ raise TypeError (f"Invalid node property specification: { spec } " )
221276
222277 def _rel_projections_spec (self , spec : Any ) -> list [RelationshipProjection ]:
223278 if spec is None or spec is False :
0 commit comments