2828import matplotlib .pyplot as plot
2929import seaborn
3030
31+ from visualization import plot_annotation_style , annotate_each , annotate_each_with_index , zoom_into_center , zoom_into_center_while_preserving_scores_above_threshold , zoom_into_center_while_preserving_top_scores
3132
3233class Parameters :
3334 required_parameters_ = ["projection_node_label" ]
@@ -256,19 +257,6 @@ def get_clusters_by_criteria(
256257 return data [(data [by ] >= threshold ) | (data [label_column_name ] == - 1 )]
257258
258259
259- plot_annotation_style : dict = {
260- 'textcoords' : 'offset points' ,
261- 'arrowprops' : dict (arrowstyle = '->' , color = 'black' , alpha = 0.3 ),
262- 'fontsize' : 6 ,
263- 'backgroundcolor' : 'white' ,
264- 'bbox' : dict (boxstyle = 'round,pad=0.4' ,
265- edgecolor = 'silver' ,
266- facecolor = 'whitesmoke' ,
267- alpha = 1
268- )
269- }
270-
271-
272260def get_file_path (name : str , parameters : Parameters , extension : str = 'svg' ) -> str :
273261 name = parameters .get_report_directory () + '/' + name .replace (' ' , '_' ) + '.' + extension
274262 if parameters .is_verbose ():
@@ -460,33 +448,35 @@ def plot_clustering_coefficient_vs_page_rank(
460448 'clusterNoise' : clustering_noise ,
461449 }, index = clustering_coefficients .index )
462450
451+ common_column_names_for_annotations = {
452+ "name_column" : 'shortName' ,
453+ "x_position_column" : 'clusteringCoefficient' ,
454+ "y_position_column" : 'pageRank'
455+ }
456+
463457 # Annotate points with their names. Only keep points whose page rank is more than 1.5 standard deviations above the mean (at most 10 of them)
464458 mean_page_rank = page_ranks .mean ()
465459 standard_deviation_page_rank = page_ranks .std ()
466460 threshold_page_rank = mean_page_rank + 1.5 * standard_deviation_page_rank
467- significant_points = combined_data [combined_data ['pageRank' ] > threshold_page_rank ].reset_index (drop = True ).head (10 )
468- for dataframe_index , row in significant_points .iterrows ():
469- index = typing .cast (int , dataframe_index )
470- plot .annotate (
471- text = row ['shortName' ],
472- xy = (row ['clusteringCoefficient' ], row ['pageRank' ]),
473- xytext = (5 , 5 + index * 10 ), # Offset y position for better visibility
474- ** plot_annotation_style
475- )
461+ significant_points = combined_data [combined_data ['pageRank' ] > threshold_page_rank ].sort_values (by = 'pageRank' , ascending = False ).reset_index (drop = True ).head (10 )
462+ annotate_each_with_index (
463+ significant_points ,
464+ using = plot .annotate ,
465+ value_column = 'pageRank' ,
466+ ** common_column_names_for_annotations
467+ )
476468
477469 # Of the 20 points with the highest clustering coefficients, annotate only the 5 with the lowest page ranks
478470 combined_data ['page_rank_ranking' ] = combined_data ['pageRank' ].rank (ascending = False ).astype (int )
479471 combined_data ['clustering_coefficient_ranking' ] = combined_data ['clusteringCoefficient' ].rank (ascending = False ).astype (int )
480472 top_clustering_coefficients = combined_data .sort_values (by = 'clusteringCoefficient' , ascending = False ).reset_index (drop = True ).head (20 )
481473 top_clustering_coefficients = top_clustering_coefficients .sort_values (by = 'pageRank' , ascending = True ).reset_index (drop = True ).head (5 )
482- for dataframe_index , row in top_clustering_coefficients .iterrows ():
483- index = typing .cast (int , dataframe_index )
484- plot .annotate (
485- text = f"{ row ['shortName' ]} (score { row ['pageRank' ]:.4f} )" ,
486- xy = (row ['clusteringCoefficient' ], row ['pageRank' ]),
487- xytext = (5 , 5 + index * 10 ), # Offset y position for better visibility
488- ** plot_annotation_style
489- )
474+ annotate_each_with_index (
475+ top_clustering_coefficients ,
476+ using = plot .annotate ,
477+ value_column = 'clusteringCoefficient' ,
478+ ** common_column_names_for_annotations
479+ )
490480
491481 # plot.yscale('log') # Use logarithmic scale for better visibility of differences
492482 plot .grid (True )
@@ -523,9 +513,16 @@ def truncate(text: str, max_length: int):
523513 # Setup columns
524514 node_size_column = centrality_column_name
525515
516+ clustering_visualization_dataframe_zoomed = zoom_into_center (
517+ clustering_visualization_dataframe ,
518+ x_position_column ,
519+ y_position_column ,
520+ percentile_of_distance_to_center = 0.9
521+ )
522+
526523 # Separate HDBSCAN non-noise and noise nodes
527- node_embeddings_without_noise = clustering_visualization_dataframe [ clustering_visualization_dataframe [cluster_label_column_name ] != - 1 ]
528- node_embeddings_noise_only = clustering_visualization_dataframe [ clustering_visualization_dataframe [cluster_label_column_name ] == - 1 ]
524+ node_embeddings_without_noise = clustering_visualization_dataframe_zoomed [ clustering_visualization_dataframe_zoomed [cluster_label_column_name ] != - 1 ]
525+ node_embeddings_noise_only = clustering_visualization_dataframe_zoomed [ clustering_visualization_dataframe_zoomed [cluster_label_column_name ] == - 1 ]
529526
530527 # ------------------------------------------
531528 # Subplot: HDBSCAN Clustering with KDE
@@ -586,13 +583,15 @@ def truncate(text: str, max_length: int):
586583
587584 # Annotate medoids of the cluster
588585 medoids = cluster_nodes [cluster_nodes [cluster_medoid_column_name ] == 1 ]
589- for index , row in medoids .iterrows ():
590- plot .annotate (
591- text = f"{ truncate (row [code_unit_column_name ], 30 )} ({ row [cluster_label_column_name ]} )" ,
592- xy = (row [x_position_column ], row [y_position_column ]),
593- xytext = (5 , 5 ), # Offset for better visibility
594- ** plot_annotation_style
595- )
586+ annotate_each (
587+ medoids ,
588+ using = plot .annotate ,
589+ name_column = code_unit_column_name ,
590+ x_position_column = x_position_column ,
591+ y_position_column = y_position_column ,
592+ cluster_label_column = cluster_label_column_name ,
593+ alpha = 0.6
594+ )
596595
597596 plot .savefig (plot_file_path )
598597
@@ -609,40 +608,48 @@ def plot_clusters_probabilities(
609608 size_column : str = "pageRank" ,
610609 x_position_column : str = 'embeddingVisualizationX' ,
611610 y_position_column : str = 'embeddingVisualizationY' ,
611+ annotate_n_lowest_probabilities : int = 10
612612) -> None :
613613
614614 if clustering_visualization_dataframe .empty :
615615 print ("No projected data to plot available" )
616616 return
617617
618- def truncate (text : str , max_length : int = 22 ):
619- if len (text ) <= max_length :
620- return text
621- return text [:max_length - 3 ] + "..."
618+ clustering_visualization_dataframe_zoomed = zoom_into_center_while_preserving_top_scores (
619+ clustering_visualization_dataframe ,
620+ x_position_column ,
621+ y_position_column ,
622+ cluster_probability_column ,
623+ annotate_n_lowest_probabilities ,
624+ lowest_scores = True
625+ )
626+
627+ cluster_noise = clustering_visualization_dataframe_zoomed [clustering_visualization_dataframe_zoomed [cluster_label_column ] == - 1 ]
628+ cluster_non_noise = clustering_visualization_dataframe_zoomed [clustering_visualization_dataframe_zoomed [cluster_label_column ] != - 1 ]
629+ cluster_even_labels = clustering_visualization_dataframe_zoomed [clustering_visualization_dataframe_zoomed [cluster_label_column ] % 2 == 0 ]
630+ cluster_odd_labels = clustering_visualization_dataframe_zoomed [clustering_visualization_dataframe_zoomed [cluster_label_column ] % 2 == 1 ]
622631
623- cluster_noise = clustering_visualization_dataframe [clustering_visualization_dataframe [cluster_label_column ] == - 1 ]
624- cluster_non_noise = clustering_visualization_dataframe [clustering_visualization_dataframe [cluster_label_column ] != - 1 ]
625- cluster_even_labels = clustering_visualization_dataframe [clustering_visualization_dataframe [cluster_label_column ] % 2 == 0 ]
626- cluster_odd_labels = clustering_visualization_dataframe [clustering_visualization_dataframe [cluster_label_column ] % 2 == 1 ]
632+ def get_common_plot_parameters (data : pd .DataFrame ) -> dict :
633+ return {
634+ "x" : data [x_position_column ],
635+ "y" : data [y_position_column ],
636+ "s" : data [size_column ] * 10 + 2 ,
637+ }
627638
628639 plot .figure (figsize = (10 , 10 ))
629640 plot .title (title )
630641
631642 # Plot noise
632643 plot .scatter (
633- x = cluster_noise [x_position_column ],
634- y = cluster_noise [y_position_column ],
635- s = cluster_noise [size_column ] * 10 + 2 ,
644+ ** get_common_plot_parameters (cluster_noise ),
636645 color = 'lightgrey' ,
637646 alpha = 0.4 ,
638647 label = 'Noise'
639648 )
640649
641650 # Plot even labels
642651 plot .scatter (
643- x = cluster_even_labels [x_position_column ],
644- y = cluster_even_labels [y_position_column ],
645- s = cluster_even_labels [size_column ] * 10 + 2 ,
652+ ** get_common_plot_parameters (cluster_even_labels ),
646653 c = cluster_even_labels [cluster_probability_column ],
647654 vmin = 0.6 ,
648655 vmax = 1.0 ,
@@ -653,9 +660,7 @@ def truncate(text: str, max_length: int = 22):
653660
654661 # Plot odd labels
655662 plot .scatter (
656- x = cluster_odd_labels [x_position_column ],
657- y = cluster_odd_labels [y_position_column ],
658- s = cluster_odd_labels [size_column ] * 10 + 2 ,
663+ ** get_common_plot_parameters (cluster_odd_labels ),
659664 c = cluster_odd_labels [cluster_probability_column ],
660665 vmin = 0.6 ,
661666 vmax = 1.0 ,
@@ -665,28 +670,33 @@ def truncate(text: str, max_length: int = 22):
665670 )
666671
667672 # Annotate medoids of the cluster
668- cluster_medoids = cluster_non_noise [cluster_non_noise [cluster_medoid_column ] == 1 ].sort_values (by = cluster_size_column , ascending = False ).head (20 )
669- for index , row in cluster_medoids .iterrows ():
670- mean_cluster_probability = cluster_non_noise [cluster_non_noise [cluster_label_column ] == row [cluster_label_column ]][cluster_probability_column ].mean ()
671- plot .annotate (
672- text = f"{ truncate (row [code_unit_column ])} (cluster { row [cluster_label_column ]} ) (p={ mean_cluster_probability :.4f} )" ,
673- xy = (row [x_position_column ], row [y_position_column ]),
674- xytext = (5 , 5 ),
675- alpha = 0.4 ,
676- ** plot_annotation_style
677- )
673+ # Find center node of each cluster (medoid), sort them by cluster size descending and add a mean cluster probability column
674+ cluster_medoids = cluster_non_noise [cluster_non_noise [cluster_medoid_column ] == 1 ]
675+ cluster_medoids_by_cluster_size = cluster_medoids .sort_values (by = cluster_size_column , ascending = False ).head (20 )
676+ mean_probabilities = cluster_non_noise .groupby (cluster_label_column )[cluster_probability_column ].mean ().rename ('mean_cluster_probability' )
677+ cluster_medoids_with_mean_probabilites = cluster_medoids_by_cluster_size .merge (mean_probabilities , on = cluster_label_column , how = 'left' )
678+
679+ annotate_each (
680+ cluster_medoids_with_mean_probabilites ,
681+ using = plot .annotate ,
682+ name_column = code_unit_column ,
683+ x_position_column = x_position_column ,
684+ y_position_column = y_position_column ,
685+ cluster_label_column = cluster_label_column ,
686+ probability_column = 'mean_cluster_probability' ,
687+ alpha = 0.4
688+ )
678689
679- lowest_probabilities = cluster_non_noise .sort_values (by = cluster_probability_column , ascending = True ).reset_index ().head (10 )
680- lowest_probabilities_in_reverse_order = lowest_probabilities .iloc [::- 1 ] # plot most important annotations last to overlap less important ones
681- for dataframe_index , row in lowest_probabilities_in_reverse_order .iterrows ():
682- index = typing .cast (int , dataframe_index )
683- plot .annotate (
684- text = f"#{ index } :{ truncate (row [code_unit_column ], 20 )} ({ row [cluster_probability_column ]:.4f} )" ,
685- xy = (row [x_position_column ], row [y_position_column ]),
686- xytext = (5 , 5 + index * 10 ),
687- color = 'red' ,
688- ** plot_annotation_style
689- )
690+ lowest_probabilities = cluster_non_noise .sort_values (by = cluster_probability_column , ascending = True ).reset_index ().head (annotate_n_lowest_probabilities )
691+ annotate_each_with_index (
692+ lowest_probabilities ,
693+ using = plot .annotate ,
694+ name_column = code_unit_column ,
695+ x_position_column = x_position_column ,
696+ y_position_column = y_position_column ,
697+ probability_column = cluster_probability_column ,
698+ color = "red"
699+ )
690700
691701 plot .savefig (plot_file_path )
692702
@@ -722,23 +732,31 @@ def plot_cluster_noise(
722732 color_90_quantile = noise_points [color_column_name ].quantile (0.90 )
723733 color_threshold = max (color_10th_highest_value , color_90_quantile )
724734
735+ noise_points_zoomed = zoom_into_center_while_preserving_scores_above_threshold (
736+ noise_points ,
737+ x_position_column ,
738+ y_position_column ,
739+ color_column_name ,
740+ color_threshold
741+ )
742+
725743 # Color the color column values above the 90% quantile threshold red, the rest light grey
726- colors = noise_points [color_column_name ].apply (
744+ colors = noise_points_zoomed [color_column_name ].apply (
727745 lambda x : "red" if x >= color_threshold else "lightgrey"
728746 )
729- normalized_size = noise_points [size_column_name ] / noise_points [size_column_name ].max ()
747+ normalized_size = noise_points_zoomed [size_column_name ] / noise_points_zoomed [size_column_name ].max ()
730748
731749 # Scatter plot for noise points
732750 plot .scatter (
733- x = noise_points [x_position_column ],
734- y = noise_points [y_position_column ],
751+ x = noise_points_zoomed [x_position_column ],
752+ y = noise_points_zoomed [y_position_column ],
735753 s = normalized_size .clip (lower = 0.01 ) * 200 + 2 ,
736754 c = colors ,
737755 alpha = 0.6
738756 )
739757
740758 # Annotate the largest 10 points and all colored ones with their names
741- for index , row in noise_points .iterrows ():
759+ for index , row in noise_points_zoomed .iterrows ():
742760 index = typing .cast (int , index )
743761 if colors [index ] != 'red' and index >= 10 :
744762 continue
0 commit comments