Skip to content

Commit d2bcdfe

Browse files
committed
Reuse common visualization functions for more unified plots
1 parent b524736 commit d2bcdfe

File tree

4 files changed

+369
-161
lines changed

4 files changed

+369
-161
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ coverage/
9393
.ipynb_checkpoints
9494
*.nbconvert*
9595

96+
# Python
97+
__pycache__/
98+
9699
# Python environments
97100
.conda
98101

domains/anomaly-detection/anomalyDetectionFeaturePlots.py

Lines changed: 100 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import matplotlib.pyplot as plot
2929
import seaborn
3030

31+
from visualization import plot_annotation_style, annotate_each, annotate_each_with_index, zoom_into_center, zoom_into_center_while_preserving_scores_above_threshold, zoom_into_center_while_preserving_top_scores
3132

3233
class Parameters:
3334
required_parameters_ = ["projection_node_label"]
@@ -256,19 +257,6 @@ def get_clusters_by_criteria(
256257
return data[(data[by] >= threshold) | (data[label_column_name] == -1)]
257258

258259

259-
plot_annotation_style: dict = {
260-
'textcoords': 'offset points',
261-
'arrowprops': dict(arrowstyle='->', color='black', alpha=0.3),
262-
'fontsize': 6,
263-
'backgroundcolor': 'white',
264-
'bbox': dict(boxstyle='round,pad=0.4',
265-
edgecolor='silver',
266-
facecolor='whitesmoke',
267-
alpha=1
268-
)
269-
}
270-
271-
272260
def get_file_path(name: str, parameters: Parameters, extension: str = 'svg') -> str:
273261
name = parameters.get_report_directory() + '/' + name.replace(' ', '_') + '.' + extension
274262
if parameters.is_verbose():
@@ -460,33 +448,35 @@ def plot_clustering_coefficient_vs_page_rank(
460448
'clusterNoise': clustering_noise,
461449
}, index=clustering_coefficients.index)
462450

451+
common_column_names_for_annotations = {
452+
"name_column": 'shortName',
453+
"x_position_column": 'clusteringCoefficient',
454+
"y_position_column": 'pageRank'
455+
}
456+
463457
# Annotate points with their names. Filter out values with a page rank smaller than 1.5 standard deviations
464458
mean_page_rank = page_ranks.mean()
465459
standard_deviation_page_rank = page_ranks.std()
466460
threshold_page_rank = mean_page_rank + 1.5 * standard_deviation_page_rank
467-
significant_points = combined_data[combined_data['pageRank'] > threshold_page_rank].reset_index(drop=True).head(10)
468-
for dataframe_index, row in significant_points.iterrows():
469-
index = typing.cast(int, dataframe_index)
470-
plot.annotate(
471-
text=row['shortName'],
472-
xy=(row['clusteringCoefficient'], row['pageRank']),
473-
xytext=(5, 5 + index * 10), # Offset y position for better visibility
474-
**plot_annotation_style
475-
)
461+
significant_points = combined_data[combined_data['pageRank'] > threshold_page_rank].sort_values(by='pageRank', ascending=False).reset_index(drop=True).head(10)
462+
annotate_each_with_index(
463+
significant_points,
464+
using=plot.annotate,
465+
value_column='pageRank',
466+
**common_column_names_for_annotations
467+
)
476468

477469
# Annotate points with the highest clustering coefficients (top 20) and only show the lowest 5 page ranks
478470
combined_data['page_rank_ranking'] = combined_data['pageRank'].rank(ascending=False).astype(int)
479471
combined_data['clustering_coefficient_ranking'] = combined_data['clusteringCoefficient'].rank(ascending=False).astype(int)
480472
top_clustering_coefficients = combined_data.sort_values(by='clusteringCoefficient', ascending=False).reset_index(drop=True).head(20)
481473
top_clustering_coefficients = top_clustering_coefficients.sort_values(by='pageRank', ascending=True).reset_index(drop=True).head(5)
482-
for dataframe_index, row in top_clustering_coefficients.iterrows():
483-
index = typing.cast(int, dataframe_index)
484-
plot.annotate(
485-
text=f"{row['shortName']} (score {row['pageRank']:.4f})",
486-
xy=(row['clusteringCoefficient'], row['pageRank']),
487-
xytext=(5, 5 + index * 10), # Offset y position for better visibility
488-
**plot_annotation_style
489-
)
474+
annotate_each_with_index(
475+
top_clustering_coefficients,
476+
using=plot.annotate,
477+
value_column='clusteringCoefficient',
478+
**common_column_names_for_annotations
479+
)
490480

491481
# plot.yscale('log') # Use logarithmic scale for better visibility of differences
492482
plot.grid(True)
@@ -523,9 +513,16 @@ def truncate(text: str, max_length: int):
523513
# Setup columns
524514
node_size_column = centrality_column_name
525515

516+
clustering_visualization_dataframe_zoomed=zoom_into_center(
517+
clustering_visualization_dataframe,
518+
x_position_column,
519+
y_position_column,
520+
percentile_of_distance_to_center=0.9
521+
)
522+
526523
# Separate HDBSCAN non-noise and noise nodes
527-
node_embeddings_without_noise = clustering_visualization_dataframe[clustering_visualization_dataframe[cluster_label_column_name] != -1]
528-
node_embeddings_noise_only = clustering_visualization_dataframe[clustering_visualization_dataframe[cluster_label_column_name] == -1]
524+
node_embeddings_without_noise = clustering_visualization_dataframe_zoomed[clustering_visualization_dataframe_zoomed[cluster_label_column_name] != -1]
525+
node_embeddings_noise_only = clustering_visualization_dataframe_zoomed[clustering_visualization_dataframe_zoomed[cluster_label_column_name] == -1]
529526

530527
# ------------------------------------------
531528
# Subplot: HDBSCAN Clustering with KDE
@@ -586,13 +583,15 @@ def truncate(text: str, max_length: int):
586583

587584
# Annotate medoids of the cluster
588585
medoids = cluster_nodes[cluster_nodes[cluster_medoid_column_name] == 1]
589-
for index, row in medoids.iterrows():
590-
plot.annotate(
591-
text=f"{truncate(row[code_unit_column_name], 30)} ({row[cluster_label_column_name]})",
592-
xy=(row[x_position_column], row[y_position_column]),
593-
xytext=(5, 5), # Offset for better visibility
594-
**plot_annotation_style
595-
)
586+
annotate_each(
587+
medoids,
588+
using=plot.annotate,
589+
name_column=code_unit_column_name,
590+
x_position_column=x_position_column,
591+
y_position_column=y_position_column,
592+
cluster_label_column=cluster_label_column_name,
593+
alpha=0.6
594+
)
596595

597596
plot.savefig(plot_file_path)
598597

@@ -609,40 +608,48 @@ def plot_clusters_probabilities(
609608
size_column: str = "pageRank",
610609
x_position_column: str = 'embeddingVisualizationX',
611610
y_position_column: str = 'embeddingVisualizationY',
611+
annotate_n_lowest_probabilities: int = 10
612612
) -> None:
613613

614614
if clustering_visualization_dataframe.empty:
615615
print("No projected data to plot available")
616616
return
617617

618-
def truncate(text: str, max_length: int = 22):
619-
if len(text) <= max_length:
620-
return text
621-
return text[:max_length - 3] + "..."
618+
clustering_visualization_dataframe_zoomed=zoom_into_center_while_preserving_top_scores(
619+
clustering_visualization_dataframe,
620+
x_position_column,
621+
y_position_column,
622+
cluster_probability_column,
623+
annotate_n_lowest_probabilities,
624+
lowest_scores=True
625+
)
626+
627+
cluster_noise = clustering_visualization_dataframe_zoomed[clustering_visualization_dataframe_zoomed[cluster_label_column] == -1]
628+
cluster_non_noise = clustering_visualization_dataframe_zoomed[clustering_visualization_dataframe_zoomed[cluster_label_column] != -1]
629+
cluster_even_labels = clustering_visualization_dataframe_zoomed[clustering_visualization_dataframe_zoomed[cluster_label_column] % 2 == 0]
630+
cluster_odd_labels = clustering_visualization_dataframe_zoomed[clustering_visualization_dataframe_zoomed[cluster_label_column] % 2 == 1]
622631

623-
cluster_noise = clustering_visualization_dataframe[clustering_visualization_dataframe[cluster_label_column] == -1]
624-
cluster_non_noise = clustering_visualization_dataframe[clustering_visualization_dataframe[cluster_label_column] != -1]
625-
cluster_even_labels = clustering_visualization_dataframe[clustering_visualization_dataframe[cluster_label_column] % 2 == 0]
626-
cluster_odd_labels = clustering_visualization_dataframe[clustering_visualization_dataframe[cluster_label_column] % 2 == 1]
632+
def get_common_plot_parameters(data: pd.DataFrame) -> dict:
633+
return {
634+
"x": data[x_position_column],
635+
"y": data[y_position_column],
636+
"s": data[size_column] * 10 + 2,
637+
}
627638

628639
plot.figure(figsize=(10, 10))
629640
plot.title(title)
630641

631642
# Plot noise
632643
plot.scatter(
633-
x=cluster_noise[x_position_column],
634-
y=cluster_noise[y_position_column],
635-
s=cluster_noise[size_column] * 10 + 2,
644+
**get_common_plot_parameters(cluster_noise),
636645
color='lightgrey',
637646
alpha=0.4,
638647
label='Noise'
639648
)
640649

641650
# Plot even labels
642651
plot.scatter(
643-
x=cluster_even_labels[x_position_column],
644-
y=cluster_even_labels[y_position_column],
645-
s=cluster_even_labels[size_column] * 10 + 2,
652+
**get_common_plot_parameters(cluster_even_labels),
646653
c=cluster_even_labels[cluster_probability_column],
647654
vmin=0.6,
648655
vmax=1.0,
@@ -653,9 +660,7 @@ def truncate(text: str, max_length: int = 22):
653660

654661
# Plot odd labels
655662
plot.scatter(
656-
x=cluster_odd_labels[x_position_column],
657-
y=cluster_odd_labels[y_position_column],
658-
s=cluster_odd_labels[size_column] * 10 + 2,
663+
**get_common_plot_parameters(cluster_odd_labels),
659664
c=cluster_odd_labels[cluster_probability_column],
660665
vmin=0.6,
661666
vmax=1.0,
@@ -665,28 +670,33 @@ def truncate(text: str, max_length: int = 22):
665670
)
666671

667672
# Annotate medoids of the cluster
668-
cluster_medoids = cluster_non_noise[cluster_non_noise[cluster_medoid_column] == 1].sort_values(by=cluster_size_column, ascending=False).head(20)
669-
for index, row in cluster_medoids.iterrows():
670-
mean_cluster_probability = cluster_non_noise[cluster_non_noise[cluster_label_column] == row[cluster_label_column]][cluster_probability_column].mean()
671-
plot.annotate(
672-
text=f"{truncate(row[code_unit_column])} (cluster {row[cluster_label_column]}) (p={mean_cluster_probability:.4f})",
673-
xy=(row[x_position_column], row[y_position_column]),
674-
xytext=(5, 5),
675-
alpha=0.4,
676-
**plot_annotation_style
677-
)
673+
# Find center node of each cluster (medoid), sort them by cluster size descending and add a mean cluster probability column
674+
cluster_medoids = cluster_non_noise[cluster_non_noise[cluster_medoid_column] == 1]
675+
cluster_medoids_by_cluster_size = cluster_medoids.sort_values(by=cluster_size_column, ascending=False).head(20)
676+
mean_probabilities = cluster_non_noise.groupby(cluster_label_column)[cluster_probability_column].mean().rename('mean_cluster_probability')
677+
cluster_medoids_with_mean_probabilites = cluster_medoids_by_cluster_size.merge(mean_probabilities, on=cluster_label_column, how='left')
678+
679+
annotate_each(
680+
cluster_medoids_with_mean_probabilites,
681+
using=plot.annotate,
682+
name_column=code_unit_column,
683+
x_position_column=x_position_column,
684+
y_position_column=y_position_column,
685+
cluster_label_column=cluster_label_column,
686+
probability_column='mean_cluster_probability',
687+
alpha=0.4
688+
)
678689

679-
lowest_probabilities = cluster_non_noise.sort_values(by=cluster_probability_column, ascending=True).reset_index().head(10)
680-
lowest_probabilities_in_reverse_order = lowest_probabilities.iloc[::-1] # plot most important annotations last to overlap less important ones
681-
for dataframe_index, row in lowest_probabilities_in_reverse_order.iterrows():
682-
index = typing.cast(int, dataframe_index)
683-
plot.annotate(
684-
text=f"#{index}:{truncate(row[code_unit_column], 20)} ({row[cluster_probability_column]:.4f})",
685-
xy=(row[x_position_column], row[y_position_column]),
686-
xytext=(5, 5 + index * 10),
687-
color='red',
688-
**plot_annotation_style
689-
)
690+
lowest_probabilities = cluster_non_noise.sort_values(by=cluster_probability_column, ascending=True).reset_index().head(annotate_n_lowest_probabilities)
691+
annotate_each_with_index(
692+
lowest_probabilities,
693+
using=plot.annotate,
694+
name_column=code_unit_column,
695+
x_position_column=x_position_column,
696+
y_position_column=y_position_column,
697+
probability_column=cluster_probability_column,
698+
color="red"
699+
)
690700

691701
plot.savefig(plot_file_path)
692702

@@ -722,23 +732,31 @@ def plot_cluster_noise(
722732
color_90_quantile = noise_points[color_column_name].quantile(0.90)
723733
color_threshold = max(color_10th_highest_value, color_90_quantile)
724734

735+
noise_points_zoomed = zoom_into_center_while_preserving_scores_above_threshold(
736+
noise_points,
737+
x_position_column,
738+
y_position_column,
739+
color_column_name,
740+
color_threshold
741+
)
742+
725743
# Color the color column values above the 90% quantile threshold red, the rest light grey
726-
colors = noise_points[color_column_name].apply(
744+
colors = noise_points_zoomed[color_column_name].apply(
727745
lambda x: "red" if x >= color_threshold else "lightgrey"
728746
)
729-
normalized_size = noise_points[size_column_name] / noise_points[size_column_name].max()
747+
normalized_size = noise_points_zoomed[size_column_name] / noise_points_zoomed[size_column_name].max()
730748

731749
# Scatter plot for noise points
732750
plot.scatter(
733-
x=noise_points[x_position_column],
734-
y=noise_points[y_position_column],
751+
x=noise_points_zoomed[x_position_column],
752+
y=noise_points_zoomed[y_position_column],
735753
s=normalized_size.clip(lower=0.01) * 200 + 2,
736754
c=colors,
737755
alpha=0.6
738756
)
739757

740758
# Annotate the largest 10 points and all colored ones with their names
741-
for index, row in noise_points.iterrows():
759+
for index, row in noise_points_zoomed.iterrows():
742760
index = typing.cast(int, index)
743761
if colors[index] != 'red' and index >= 10:
744762
continue

0 commit comments

Comments
 (0)