@@ -235,7 +235,12 @@ def __repr__(self) -> str:
235235 return f"{ self .__class__ .__name__ } ({ self ._graph_info (yields = yield_fields ).to_dict ()} )"
236236
237237 def visualize (
238- self , node_count : int = 100 , center_nodes : Optional [List [int ]] = None , include_node_properties : List [str ] = None
238+ self ,
239+ notebook : bool = True ,
240+ node_count : int = 100 ,
241+ center_nodes : Optional [List [int ]] = None ,
242+ include_node_properties : List [str ] = None ,
243+ color_property : Optional [str ] = None ,
239244 ) -> Any :
240245 visual_graph = self ._name
241246 if self .node_count () > node_count :
@@ -262,6 +267,12 @@ def visualize(
262267 if include_node_properties is not None :
263268 node_properties .extend (include_node_properties )
264269
270+ if color_property is not None :
271+ node_properties .append (color_property )
272+
273+ # Remove possible duplicates
274+ node_properties = list (set (node_properties ))
275+
265276 result = self ._query_runner .call_procedure (
266277 endpoint = "gds.graph.nodeProperties.stream" ,
267278 params = CallParameters (
@@ -307,15 +318,20 @@ def visualize(
307318 from pyvis .network import Network
308319
309320 net = Network (
310- notebook = True ,
311- cdn_resources = "remote" ,
321+ notebook = True if notebook else False ,
322+ cdn_resources = "remote" if notebook else "local" ,
312323 bgcolor = "#222222" , # Dark background
313324 font_color = "white" ,
314325 height = "750px" , # Modify according to your screen size
315326 width = "100%" ,
316327 )
317328
318- label_to_color = {label : self ._random_bright_color () for label in self .node_labels ()}
329+ if color_property is None :
330+ color_map = {label : self ._random_bright_color () for label in self .node_labels ()}
331+ else :
332+ color_map = {
333+ prop_val : self ._random_bright_color () for prop_val in node_properties_df [color_property ].unique ()
334+ }
319335
320336 for _ , node in node_properties_df .iterrows ():
321337 title = f"Node ID: { node ['nodeId' ]} \n Labels: { node ['nodeLabels' ]} "
@@ -324,10 +340,15 @@ def visualize(
324340 for prop in include_node_properties :
325341 title += f"\n { prop } = { node [prop ]} "
326342
343+ if color_property is None :
344+ color = color_map [node ["nodeLabels" ][0 ]]
345+ else :
346+ color = color_map [node [color_property ]]
347+
327348 net .add_node (
328349 int (node ["nodeId" ]),
329350 value = node [pr_prop ],
330- color = label_to_color [ node [ "nodeLabels" ][ 0 ]] ,
351+ color = color ,
331352 title = title ,
332353 )
333354
0 commit comments