11from __future__ import annotations
22
3+ import colorsys
4+ import random
35from types import TracebackType
46from typing import Any , List , Optional , Type , Union
57from uuid import uuid4
@@ -232,16 +234,18 @@ def __repr__(self) -> str:
232234 ]
233235 return f"{ self .__class__ .__name__ } ({ self ._graph_info (yields = yield_fields ).to_dict ()} )"
234236
235- def visualize (self , node_count : int = 100 ) :
237+ def visualize (self , node_count : int = 100 , center_nodes : Optional [ List [ int ]] = None ) -> Any :
236238 visual_graph = self ._name
237239 if self .node_count () > node_count :
238- ratio = float (node_count ) / self .node_count ()
239240 visual_graph = str (uuid4 ())
241+ config = dict (samplingRatio = float (node_count ) / self .node_count ())
242+
243+ if center_nodes is not None :
244+ config ["startNodes" ] = center_nodes
245+
240246 self ._query_runner .call_procedure (
241247 endpoint = "gds.graph.sample.rwr" ,
242- params = CallParameters (
243- graph_name = visual_graph , fromGraphName = self ._name , config = dict (samplingRatio = ratio )
244- ),
248+ params = CallParameters (graph_name = visual_graph , fromGraphName = self ._name , config = config ),
245249 custom_error = False ,
246250 )
247251
@@ -254,13 +258,22 @@ def visualize(self, node_count: int = 100):
254258
255259 result = self ._query_runner .call_procedure (
256260 endpoint = "gds.graph.nodeProperties.stream" ,
257- params = CallParameters (graph_name = visual_graph , properties = [pr_prop ]),
261+ params = CallParameters (
262+ graph_name = visual_graph ,
263+ properties = [pr_prop ],
264+ nodeLabels = self .node_labels (),
265+ config = dict (listNodeLabels = True ),
266+ ),
258267 custom_error = False ,
259268 )
260269
261270 # new format was requested, but the query was run via Cypher
262271 if "propertyValue" in result .keys ():
263272 wide_result = result .pivot (index = ["nodeId" ], columns = ["nodeProperty" ], values = "propertyValue" )
273+ # nodeLabels cannot be an index column of the pivot as its not hashable
274+ # so we need to manually join it back in
275+ labels_df = result [["nodeId" , "nodeLabels" ]].set_index ("nodeId" )
276+ wide_result = wide_result .join (labels_df , on = "nodeId" )
264277 result = wide_result .reset_index ()
265278 result .columns .name = None
266279 node_properties_df = result
@@ -271,6 +284,7 @@ def visualize(self, node_count: int = 100):
271284 custom_error = False ,
272285 )
273286
287+ # Clean up
274288 if visual_graph != self ._name :
275289 self ._query_runner .call_procedure (
276290 endpoint = "gds.graph.drop" ,
@@ -289,16 +303,28 @@ def visualize(self, node_count: int = 100):
289303 net = Network (
290304 notebook = True ,
291305 cdn_resources = "remote" ,
292- bgcolor = "#222222" ,
306+ bgcolor = "#222222" , # Dark background
293307 font_color = "white" ,
294308 height = "750px" , # Modify according to your screen size
295309 width = "100%" ,
296310 )
297311
312+ label_to_color = {label : self ._random_bright_color () for label in self .node_labels ()}
313+
298314 for _ , node in node_properties_df .iterrows ():
299- net .add_node (int (node ["nodeId" ]), value = node [pr_prop ])
315+ net .add_node (
316+ int (node ["nodeId" ]),
317+ value = node [pr_prop ],
318+ color = label_to_color [node ["nodeLabels" ][0 ]],
319+ title = str (node ["nodeId" ]),
320+ )
300321
301322 # Add all the relationships
302323 net .add_edges (zip (relationships_df ["sourceNodeId" ], relationships_df ["targetNodeId" ]))
303324
304325 return net .show (f"{ self ._name } .html" )
326+
327+ @staticmethod
328+ def _random_bright_color () -> str :
329+ h = random .randint (0 , 255 ) / 255.0
330+ return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (h , 0.7 , 1.0 )))
0 commit comments