@@ -82,7 +82,6 @@ def node_count(self) -> int:
8282 """
8383 Returns:
8484 the number of nodes in the graph
85-
8685 """
8786 return self ._graph_info (["nodeCount" ]) # type: ignore
8887
@@ -191,7 +190,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
191190
192191 Returns:
193192 the result of the drop operation
194-
195193 """
196194 result = self ._query_runner .call_procedure (
197195 endpoint = "gds.graph.drop" ,
@@ -205,7 +203,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
205203 """
206204 Returns:
207205 the creation time of the graph
208-
209206 """
210207 return self ._graph_info (["creationTime" ])
211208
@@ -236,12 +233,53 @@ def __repr__(self) -> str:
236233
237234 def visualize (
238235 self ,
239- notebook : bool = True ,
240236 node_count : int = 100 ,
237+ directed : bool = True ,
241238 center_nodes : Optional [List [int ]] = None ,
242- include_node_properties : List [str ] = None ,
243239 color_property : Optional [str ] = None ,
240+ size_property : Optional [str ] = None ,
241+ include_node_properties : Optional [List [str ]] = None ,
242+ rel_weight_property : Optional [str ] = None ,
243+ notebook : bool = True ,
244+ px_height : int = 750 ,
245+ theme : str = "dark" ,
244246 ) -> Any :
247+ """
248+ Visualize the `Graph` in an interactive graphical interface.
249+ The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
250+
251+ Args:
252+ node_count: number of nodes in the graph to be visualized
253+ directed: whether or not to display relationships as directed
254+ center_nodes: nodes around subgraph will be sampled, if sampling is necessary
255+ color_property: node property that determines node categories for coloring. Default is to use node labels
256+ size_property: node property that determines the size of nodes. Default is to compute a page rank for this
257+ include_node_properties: node properties to include for mouse-over inspection
258+ rel_weight_property: relationship property that determines width of relationships
259+ notebook: whether or not the code is run in a notebook
260+ px_height: the height of the graphic containing output the visualization
261+ theme: coloring theme for the visualization. "light" or "dark"
262+
263+ Returns:
264+ an interactive graphical visualization of the specified graph
265+ """
266+
267+ actual_node_properties = self .node_properties ()
268+ if color_property not in actual_node_properties :
269+ raise ValueError (f"There is no node property '{ color_property } ' in graph '{ self ._name } '" )
270+
271+ if size_property not in actual_node_properties :
272+ raise ValueError (f"There is no node property '{ size_property } ' in graph '{ self ._name } '" )
273+
274+ if include_node_properties is not None :
275+ for prop in include_node_properties :
276+ if prop not in actual_node_properties :
277+ raise ValueError (f"There is no node property '{ prop } ' in graph '{ self ._name } '" )
278+
279+ actual_rel_properties = self .relationship_properties ()
280+ if rel_weight_property not in actual_rel_properties :
281+ raise ValueError (f"There is no relationship property '{ rel_weight_property } ' in graph '{ self ._name } '" )
282+
245283 visual_graph = self ._name
246284 if self .node_count () > node_count :
247285 visual_graph = str (uuid4 ())
@@ -256,14 +294,19 @@ def visualize(
256294 custom_error = False ,
257295 )
258296
259- pr_prop = str (uuid4 ())
260- self ._query_runner .call_procedure (
261- endpoint = "gds.pageRank.mutate" ,
262- params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = pr_prop )),
263- custom_error = False ,
264- )
297+ # Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
298+ if size_property is None :
299+ size_property = str (uuid4 ())
300+ self ._query_runner .call_procedure (
301+ endpoint = "gds.pageRank.mutate" ,
302+ params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = size_property )),
303+ custom_error = False ,
304+ )
305+ clean_up_size_prop = True
306+ else :
307+ clean_up_size_prop = False
265308
266- node_properties = [pr_prop ]
309+ node_properties = [size_property ]
267310 if include_node_properties is not None :
268311 node_properties .extend (include_node_properties )
269312
@@ -295,11 +338,18 @@ def visualize(
295338 result .columns .name = None
296339 node_properties_df = result
297340
298- relationships_df = self ._query_runner .call_procedure (
299- endpoint = "gds.graph.relationships.stream" ,
300- params = CallParameters (graph_name = visual_graph ),
301- custom_error = False ,
302- )
341+ if rel_weight_property is None :
342+ relationships_df = self ._query_runner .call_procedure (
343+ endpoint = "gds.graph.relationships.stream" ,
344+ params = CallParameters (graph_name = visual_graph ),
345+ custom_error = False ,
346+ )
347+ else :
348+ relationships_df = self ._query_runner .call_procedure (
349+ endpoint = "gds.graph.relationshipProperty.stream" ,
350+ params = CallParameters (graph_name = visual_graph , properties = rel_weight_property ),
351+ custom_error = False ,
352+ )
303353
304354 # Clean up
305355 if visual_graph != self ._name :
@@ -308,10 +358,10 @@ def visualize(
308358 params = CallParameters (graph_name = visual_graph ),
309359 custom_error = False ,
310360 )
311- else :
361+ elif clean_up_size_prop :
312362 self ._query_runner .call_procedure (
313363 endpoint = "gds.graph.nodeProperties.drop" ,
314- params = CallParameters (graph_name = visual_graph , nodeProperties = pr_prop ),
364+ params = CallParameters (graph_name = visual_graph , nodeProperties = size_property ),
315365 custom_error = False ,
316366 )
317367
@@ -320,19 +370,21 @@ def visualize(
320370 net = Network (
321371 notebook = True if notebook else False ,
322372 cdn_resources = "remote" if notebook else "local" ,
323- bgcolor = "#222222" , # Dark background
324- font_color = "white" ,
325- height = "750px" , # Modify according to your screen size
373+ directed = directed ,
374+ bgcolor = "#222222" if theme == "dark" else "#FDFDFD" ,
375+ font_color = "white" if theme == "dark" else "black" ,
376+ height = f"{ px_height } px" ,
326377 width = "100%" ,
327378 )
328379
329380 if color_property is None :
330- color_map = {label : self ._random_bright_color ( ) for label in self .node_labels ()}
381+ color_map = {label : self ._random_themed_color ( theme ) for label in self .node_labels ()}
331382 else :
332383 color_map = {
333- prop_val : self ._random_bright_color ( ) for prop_val in node_properties_df [color_property ].unique ()
384+ prop_val : self ._random_themed_color ( theme ) for prop_val in node_properties_df [color_property ].unique ()
334385 }
335386
387+ # Add all the nodes
336388 for _ , node in node_properties_df .iterrows ():
337389 title = f"Node ID: { node ['nodeId' ]} \n Labels: { node ['nodeLabels' ]} "
338390 if include_node_properties is not None :
@@ -347,17 +399,22 @@ def visualize(
347399
348400 net .add_node (
349401 int (node ["nodeId" ]),
350- value = node [pr_prop ],
402+ value = node [size_property ],
351403 color = color ,
352404 title = title ,
353405 )
354406
355407 # Add all the relationships
356- net .add_edges (zip (relationships_df ["sourceNodeId" ], relationships_df ["targetNodeId" ]))
408+ for _ , rel in relationships_df .iterrows ():
409+ if rel_weight_property is None :
410+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = f"Type: { rel ['relationshipType' ]} " )
411+ else :
412+ title = f"Type: { rel ['relationshipType' ]} \n { rel_weight_property } = { rel ['rel_weight_property' ]} "
413+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = title , value = rel [rel_weight_property ])
357414
358415 return net .show (f"{ self ._name } .html" )
359416
360417 @staticmethod
361- def _random_bright_color ( ) -> str :
362- h = random . randint ( 0 , 255 ) / 255.0
363- return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (h , 0.7 , 1.0 )))
418+ def _random_themed_color ( theme ) -> str :
419+ l = 0.7 if theme == "dark" else 0.4
420+ return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (random . random (), l , 1.0 )))
0 commit comments