|
2 | 2 |
|
3 | 3 | from types import TracebackType |
4 | 4 | from typing import Any, List, Optional, Type, Union |
| 5 | +from uuid import uuid4 |
5 | 6 |
|
6 | 7 | from pandas import Series |
7 | 8 |
|
@@ -230,3 +231,74 @@ def __repr__(self) -> str: |
230 | 231 | "memoryUsage", |
231 | 232 | ] |
232 | 233 | return f"{self.__class__.__name__}({self._graph_info(yields=yield_fields).to_dict()})" |
| 234 | + |
| 235 | + def visualize(self, node_count: int = 100): |
| 236 | + visual_graph = self._name |
| 237 | + if self.node_count() > node_count: |
| 238 | + ratio = float(node_count) / self.node_count() |
| 239 | + visual_graph = str(uuid4()) |
| 240 | + self._query_runner.call_procedure( |
| 241 | + endpoint="gds.graph.sample.rwr", |
| 242 | + params=CallParameters( |
| 243 | + graph_name=visual_graph, fromGraphName=self._name, config=dict(samplingRatio=ratio) |
| 244 | + ), |
| 245 | + custom_error=False, |
| 246 | + ) |
| 247 | + |
| 248 | + pr_prop = str(uuid4()) |
| 249 | + self._query_runner.call_procedure( |
| 250 | + endpoint="gds.pageRank.mutate", |
| 251 | + params=CallParameters(graph_name=visual_graph, config=dict(mutateProperty=pr_prop)), |
| 252 | + custom_error=False, |
| 253 | + ) |
| 254 | + |
| 255 | + result = self._query_runner.call_procedure( |
| 256 | + endpoint="gds.graph.nodeProperties.stream", |
| 257 | + params=CallParameters(graph_name=visual_graph, properties=[pr_prop]), |
| 258 | + custom_error=False, |
| 259 | + ) |
| 260 | + |
| 261 | + # new format was requested, but the query was run via Cypher |
| 262 | + if "propertyValue" in result.keys(): |
| 263 | + wide_result = result.pivot(index=["nodeId"], columns=["nodeProperty"], values="propertyValue") |
| 264 | + result = wide_result.reset_index() |
| 265 | + result.columns.name = None |
| 266 | + node_properties_df = result |
| 267 | + |
| 268 | + relationships_df = self._query_runner.call_procedure( |
| 269 | + endpoint="gds.graph.relationships.stream", |
| 270 | + params=CallParameters(graph_name=visual_graph), |
| 271 | + custom_error=False, |
| 272 | + ) |
| 273 | + |
| 274 | + if visual_graph != self._name: |
| 275 | + self._query_runner.call_procedure( |
| 276 | + endpoint="gds.graph.drop", |
| 277 | + params=CallParameters(graph_name=visual_graph), |
| 278 | + custom_error=False, |
| 279 | + ) |
| 280 | + else: |
| 281 | + self._query_runner.call_procedure( |
| 282 | + endpoint="gds.graph.nodeProperties.drop", |
| 283 | + params=CallParameters(graph_name=visual_graph, nodeProperties=pr_prop), |
| 284 | + custom_error=False, |
| 285 | + ) |
| 286 | + |
| 287 | + from pyvis.network import Network |
| 288 | + |
| 289 | + net = Network( |
| 290 | + notebook=True, |
| 291 | + cdn_resources="remote", |
| 292 | + bgcolor="#222222", |
| 293 | + font_color="white", |
| 294 | + height="750px", # Modify according to your screen size |
| 295 | + width="100%", |
| 296 | + ) |
| 297 | + |
| 298 | + for _, node in node_properties_df.iterrows(): |
| 299 | + net.add_node(int(node["nodeId"]), value=node[pr_prop]) |
| 300 | + |
| 301 | + # Add all the relationships |
| 302 | + net.add_edges(zip(relationships_df["sourceNodeId"], relationships_df["targetNodeId"])) |
| 303 | + |
| 304 | + return net.show(f"{self._name}.html") |
0 commit comments