11from __future__ import annotations
2+ from itertools import chain
23
34import colorsys
45import random
@@ -82,7 +83,6 @@ def node_count(self) -> int:
8283 """
8384 Returns:
8485 the number of nodes in the graph
85-
8686 """
8787 return self ._graph_info (["nodeCount" ]) # type: ignore
8888
@@ -191,7 +191,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
191191
192192 Returns:
193193 the result of the drop operation
194-
195194 """
196195 result = self ._query_runner .call_procedure (
197196 endpoint = "gds.graph.drop" ,
@@ -205,7 +204,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
205204 """
206205 Returns:
207206 the creation time of the graph
208-
209207 """
210208 return self ._graph_info (["creationTime" ])
211209
@@ -236,12 +234,56 @@ def __repr__(self) -> str:
236234
237235 def visualize (
238236 self ,
239- notebook : bool = True ,
240237 node_count : int = 100 ,
238+ directed : bool = True ,
241239 center_nodes : Optional [List [int ]] = None ,
242- include_node_properties : List [str ] = None ,
243240 color_property : Optional [str ] = None ,
241+ size_property : Optional [str ] = None ,
242+ include_node_properties : Optional [List [str ]] = None ,
243+ rel_weight_property : Optional [str ] = None ,
244+ notebook : bool = True ,
245+ px_height : int = 750 ,
246+ theme : str = "dark" ,
244247 ) -> Any :
248+ """
249+ Visualize the `Graph` in an interactive graphical interface.
250+ The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
251+
252+ Args:
253+ node_count: number of nodes in the graph to be visualized
254+ directed: whether or not to display relationships as directed
255+ center_nodes: nodes around subgraph will be sampled, if sampling is necessary
256+ color_property: node property that determines node categories for coloring. Default is to use node labels
257+ size_property: node property that determines the size of nodes. Default is to compute a page rank for this
258+ include_node_properties: node properties to include for mouse-over inspection
259+ rel_weight_property: relationship property that determines width of relationships
260+ notebook: whether or not the code is run in a notebook
261+ px_height: the height of the graphic containing output the visualization
262+ theme: coloring theme for the visualization. "light" or "dark"
263+
264+ Returns:
265+ an interactive graphical visualization of the specified graph
266+ """
267+
268+ actual_node_properties = list (chain .from_iterable (self .node_properties ().to_dict ().values ()))
269+ if (color_property is not None ) and (color_property not in actual_node_properties ):
270+ raise ValueError (f"There is no node property '{ color_property } ' in graph '{ self ._name } '" )
271+
272+ if size_property is not None and size_property not in actual_node_properties :
273+ raise ValueError (f"There is no node property '{ size_property } ' in graph '{ self ._name } '" )
274+
275+ if include_node_properties is not None :
276+ for prop in include_node_properties :
277+ if prop not in actual_node_properties :
278+ raise ValueError (f"There is no node property '{ prop } ' in graph '{ self ._name } '" )
279+
280+ actual_rel_properties = list (chain .from_iterable (self .relationship_properties ().to_dict ().values ()))
281+ if rel_weight_property is not None and rel_weight_property not in actual_rel_properties :
282+ raise ValueError (f"There is no relationship property '{ rel_weight_property } ' in graph '{ self ._name } '" )
283+
284+ if theme not in {"light" , "dark" }:
285+ raise ValueError (f"Color `theme` '{ theme } ' is not allowed. Must be either 'light' or 'dark'" )
286+
245287 visual_graph = self ._name
246288 if self .node_count () > node_count :
247289 visual_graph = str (uuid4 ())
@@ -256,14 +298,19 @@ def visualize(
256298 custom_error = False ,
257299 )
258300
259- pr_prop = str (uuid4 ())
260- self ._query_runner .call_procedure (
261- endpoint = "gds.pageRank.mutate" ,
262- params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = pr_prop )),
263- custom_error = False ,
264- )
301+ # Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
302+ if size_property is None :
303+ size_property = str (uuid4 ())
304+ self ._query_runner .call_procedure (
305+ endpoint = "gds.pageRank.mutate" ,
306+ params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = size_property )),
307+ custom_error = False ,
308+ )
309+ clean_up_size_prop = True
310+ else :
311+ clean_up_size_prop = False
265312
266- node_properties = [pr_prop ]
313+ node_properties = [size_property ]
267314 if include_node_properties is not None :
268315 node_properties .extend (include_node_properties )
269316
@@ -295,11 +342,18 @@ def visualize(
295342 result .columns .name = None
296343 node_properties_df = result
297344
298- relationships_df = self ._query_runner .call_procedure (
299- endpoint = "gds.graph.relationships.stream" ,
300- params = CallParameters (graph_name = visual_graph ),
301- custom_error = False ,
302- )
345+ if rel_weight_property is None :
346+ relationships_df = self ._query_runner .call_procedure (
347+ endpoint = "gds.graph.relationships.stream" ,
348+ params = CallParameters (graph_name = visual_graph ),
349+ custom_error = False ,
350+ )
351+ else :
352+ relationships_df = self ._query_runner .call_procedure (
353+ endpoint = "gds.graph.relationshipProperty.stream" ,
354+ params = CallParameters (graph_name = visual_graph , properties = rel_weight_property ),
355+ custom_error = False ,
356+ )
303357
304358 # Clean up
305359 if visual_graph != self ._name :
@@ -308,10 +362,10 @@ def visualize(
308362 params = CallParameters (graph_name = visual_graph ),
309363 custom_error = False ,
310364 )
311- else :
365+ elif clean_up_size_prop :
312366 self ._query_runner .call_procedure (
313367 endpoint = "gds.graph.nodeProperties.drop" ,
314- params = CallParameters (graph_name = visual_graph , nodeProperties = pr_prop ),
368+ params = CallParameters (graph_name = visual_graph , nodeProperties = size_property ),
315369 custom_error = False ,
316370 )
317371
@@ -320,19 +374,21 @@ def visualize(
320374 net = Network (
321375 notebook = True if notebook else False ,
322376 cdn_resources = "remote" if notebook else "local" ,
323- bgcolor = "#222222" , # Dark background
324- font_color = "white" ,
325- height = "750px" , # Modify according to your screen size
377+ directed = directed ,
378+ bgcolor = "#222222" if theme == "dark" else "#F2F2F2" ,
379+ font_color = "white" if theme == "dark" else "black" ,
380+ height = f"{ px_height } px" ,
326381 width = "100%" ,
327382 )
328383
329384 if color_property is None :
330- color_map = {label : self ._random_bright_color ( ) for label in self .node_labels ()}
385+ color_map = {label : self ._random_themed_color ( theme ) for label in self .node_labels ()}
331386 else :
332387 color_map = {
333- prop_val : self ._random_bright_color ( ) for prop_val in node_properties_df [color_property ].unique ()
388+ prop_val : self ._random_themed_color ( theme ) for prop_val in node_properties_df [color_property ].unique ()
334389 }
335390
391+ # Add all the nodes
336392 for _ , node in node_properties_df .iterrows ():
337393 title = f"Node ID: { node ['nodeId' ]} \n Labels: { node ['nodeLabels' ]} "
338394 if include_node_properties is not None :
@@ -347,17 +403,22 @@ def visualize(
347403
348404 net .add_node (
349405 int (node ["nodeId" ]),
350- value = node [pr_prop ],
406+ value = node [size_property ],
351407 color = color ,
352408 title = title ,
353409 )
354410
355411 # Add all the relationships
356- net .add_edges (zip (relationships_df ["sourceNodeId" ], relationships_df ["targetNodeId" ]))
412+ for _ , rel in relationships_df .iterrows ():
413+ if rel_weight_property is None :
414+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = f"Type: { rel ['relationshipType' ]} " )
415+ else :
416+ title = f"Type: { rel ['relationshipType' ]} \n { rel_weight_property } = { rel ['rel_weight_property' ]} "
417+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = title , value = rel [rel_weight_property ])
357418
358419 return net .show (f"{ self ._name } .html" )
359420
360421 @staticmethod
361- def _random_bright_color ( ) -> str :
362- h = random . randint ( 0 , 255 ) / 255.0
363- return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (h , 0.7 , 1.0 )))
422+ def _random_themed_color ( theme ) -> str :
423+ l = 0.7 if theme == "dark" else 0.4
424+ return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (random . random (), l , 1.0 )))
0 commit comments