Skip to content

Commit 0d73142

Browse files
committed
Use UMAP for node embedding visualization
1 parent 8933400 commit 0d73142

File tree

3 files changed

+325
-9
lines changed

3 files changed

+325
-9
lines changed
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
#!/usr/bin/env python
2+
3+
# This Python script uses UMAP (https://umap-learn.readthedocs.io) to reduce the dimensionality of node embeddings to two dimensions for visualization purposes.
4+
# This is useful to get a visual intuition about the structure of the code units (like Java packages) and their dependencies.
5+
# The resulting 2D coordinates are written back to Neo4j for further use.
6+
7+
# Prerequisite:
8+
# - Provide the password for Neo4j in the environment variable "NEO4J_INITIAL_PASSWORD".
9+
# - Already existing Graph with analyzed code units (like Java Packages) and their dependencies.
10+
# - Already existing node embeddings for the code units, e.g. generated by Fast Random Projection (FastRP) or other algorithms.
11+
12+
import typing
13+
14+
import os
15+
import sys
16+
import argparse
17+
import pprint
18+
19+
import pandas as pd
20+
import numpy as np
21+
22+
from neo4j import GraphDatabase, Driver
23+
import umap
24+
25+
26+
class Parameters:
27+
required_parameters_ = ["projection_node_label"]
28+
29+
def __init__(self, input_parameters: typing.Dict[str, str], verbose: bool = False):
30+
self.query_parameters_ = input_parameters.copy() # copy enforces immutability
31+
self.verbose_ = verbose
32+
33+
def __repr__(self):
34+
pretty_dict = pprint.pformat(self.query_parameters_, indent=4)
35+
return f"Parameters: verbose={self.verbose_}, query_parameters:\n{pretty_dict}"
36+
37+
@staticmethod
38+
def log_dependency_versions_() -> None:
39+
print('---------------------------------------')
40+
41+
print('Python version: {}'.format(sys.version))
42+
43+
from numpy import __version__ as numpy_version
44+
print('numpy version: {}'.format(numpy_version))
45+
46+
from pandas import __version__ as pandas_version
47+
print('pandas version: {}'.format(pandas_version))
48+
49+
from neo4j import __version__ as neo4j_version
50+
print('neo4j version: {}'.format(neo4j_version))
51+
52+
from umap import __version__ as umap_version
53+
print('umap version: {}'.format(umap_version))
54+
55+
print('---------------------------------------')
56+
57+
@classmethod
58+
def from_input_parameters(cls, input_parameters: typing.Dict[str, str], verbose: bool = False):
59+
"""
60+
Creates a Parameters instance from a dictionary of input parameters.
61+
The dictionary must contain the following keys:
62+
- "projection_node_label": The node type of the projection.
63+
"""
64+
missing_parameters = [parameter for parameter in cls.required_parameters_ if parameter not in input_parameters]
65+
if missing_parameters:
66+
raise ValueError("Missing parameters:", missing_parameters)
67+
created_parameters = cls(input_parameters, verbose)
68+
if created_parameters.is_verbose():
69+
print(created_parameters)
70+
cls.log_dependency_versions_()
71+
return created_parameters
72+
73+
@classmethod
74+
def example(cls):
75+
return cls(dict(projection_node_label="Package"))
76+
77+
def get_query_parameters(self) -> typing.Dict[str, str]:
78+
return self.query_parameters_.copy() # copy enforces immutability
79+
80+
def clone_with_projection_name(self, projection_name: str):
81+
updated_parameter = self.get_query_parameters()
82+
updated_parameter.update({"projection_name": projection_name})
83+
return Parameters(updated_parameter)
84+
85+
def get_projection_node_label(self) -> str:
86+
return self.query_parameters_["projection_node_label"]
87+
88+
def is_verbose(self) -> bool:
89+
return self.verbose_
90+
91+
92+
def parse_input_parameters() -> Parameters:
93+
# Convert list of "key=value" strings to a dictionary
94+
def parse_key_value_list(param_list: typing.List[str]) -> typing.Dict[str, str]:
95+
param_dict = {}
96+
for item in param_list:
97+
if '=' in item:
98+
key, value = item.split('=', 1)
99+
param_dict[key] = value
100+
return param_dict
101+
102+
parser = argparse.ArgumentParser(
103+
description="Unsupervised clustering to assign labels to code units (Java packages, types,...) and their dependencies based on how structurally similar they are within a software system.")
104+
parser.add_argument('--verbose', action='store_true', help='Enable verbose mode to log all details')
105+
parser.add_argument('query_parameters', nargs='*', type=str, help='List of key=value Cypher query parameters')
106+
parser.set_defaults(verbose=False)
107+
args = parser.parse_args()
108+
return Parameters.from_input_parameters(parse_key_value_list(args.query_parameters), args.verbose)
109+
110+
111+
def get_graph_database_driver() -> Driver:
112+
driver = GraphDatabase.driver(
113+
uri="bolt://localhost:7687",
114+
auth=("neo4j", os.environ.get("NEO4J_INITIAL_PASSWORD"))
115+
)
116+
driver.verify_connectivity()
117+
return driver
118+
119+
120+
def query_cypher_to_data_frame(query: typing.LiteralString, parameters: typing.Optional[typing.Dict[str, typing.Any]] = None):
121+
records, summary, keys = driver.execute_query(query, parameters_=parameters)
122+
return pd.DataFrame([record.values() for record in records], columns=keys)
123+
124+
125+
def write_batch_data_into_database(dataframe: pd.DataFrame, node_label: str, id_column: str = "nodeElementId", batch_size: int = 1000):
126+
"""
127+
Writes the given dataframe to the Neo4j database using a batch write operation.
128+
129+
Parameters:
130+
- dataframe: The pandas DataFrame to write.
131+
- label: The label to use for the nodes in the Neo4j database.
132+
- id_column: The name of the column in the DataFrame that contains the node IDs.
133+
- cypher_query_file: The file containing the Cypher query for writing the data.
134+
- batch_size: The number of rows to write in each batch.
135+
"""
136+
def prepare_rows(dataframe):
137+
rows = []
138+
for _, row in dataframe.iterrows():
139+
properties_without_id = row.drop(labels=[id_column]).to_dict()
140+
rows.append({
141+
"nodeId": row[id_column],
142+
"properties": properties_without_id
143+
})
144+
return rows
145+
146+
def update_batch(transaction, rows):
147+
query = """
148+
UNWIND $rows AS row
149+
MATCH (codeUnit)
150+
WHERE elementId(codeUnit) = row.nodeId
151+
AND $node_label IN labels(codeUnit)
152+
SET codeUnit += row.properties
153+
"""
154+
transaction.run(query, rows=rows, node_label=node_label)
155+
156+
with driver.session() as session:
157+
for start in range(0, len(dataframe), batch_size):
158+
batch_dataframe = dataframe.iloc[start:start + batch_size]
159+
batch_rows = prepare_rows(batch_dataframe)
160+
return session.execute_write(update_batch, batch_rows)
161+
162+
163+
def prepare_node_embeddings_for_2d_visualization(embeddings: pd.DataFrame) -> pd.DataFrame:
164+
"""
165+
Reduces the dimensionality of the node embeddings (e.g. 64 floating point numbers in an array)
166+
to two dimensions for 2D visualization using UMAP.
167+
see https://umap-learn.readthedocs.io
168+
"""
169+
170+
if embeddings.empty:
171+
print("No projected data for node embeddings dimensionality reduction available")
172+
return embeddings
173+
174+
# Convert the list of embeddings to a numpy array
175+
embeddings_as_numpy_array = np.array(embeddings.embedding.to_list())
176+
177+
# Use UMAP to reduce the dimensionality to 2D for visualization
178+
umap_reducer = umap.UMAP(n_components=2, min_dist=0.3, random_state=47, n_jobs=1, verbose=parameters.is_verbose())
179+
two_dimension_node_embeddings = umap_reducer.fit_transform(embeddings_as_numpy_array)
180+
181+
# Add the 2D coordinates to the DataFrame
182+
embeddings['embeddingVisualizationX'] = two_dimension_node_embeddings[:, 0]
183+
embeddings['embeddingVisualizationY'] = two_dimension_node_embeddings[:, 1]
184+
185+
return embeddings
186+
187+
188+
# ------------------------------------------------------------------------------------------------------------
189+
# MAIN
190+
# ------------------------------------------------------------------------------------------------------------
191+
192+
parameters = parse_input_parameters()
193+
driver = get_graph_database_driver()
194+
195+
cypher_query_embeddings_: typing.LiteralString = """
196+
MATCH (codeUnit)
197+
WHERE $projection_node_label IN labels(codeUnit)
198+
AND codeUnit.embeddingsFastRandomProjectionForClustering IS NOT NULL
199+
RETURN elementId(codeUnit) AS nodeElementId
200+
,codeUnit.embeddingsFastRandomProjectionForClustering AS embedding
201+
"""
202+
203+
embeddings = query_cypher_to_data_frame(cypher_query_embeddings_, parameters.get_query_parameters())
204+
embeddings = prepare_node_embeddings_for_2d_visualization(embeddings)
205+
206+
data_to_write = pd.DataFrame(data={
207+
'nodeElementId': embeddings["nodeElementId"],
208+
'embeddingFastRandomProjectionVisualizationX': embeddings["embeddingVisualizationX"],
209+
'embeddingFastRandomProjectionVisualizationY': embeddings["embeddingVisualizationY"],
210+
})
211+
write_batch_data_into_database(data_to_write, parameters.get_projection_node_label())

0 commit comments

Comments
 (0)