@@ -111,79 +111,90 @@ def plot_distributions(ax, x, true_mixture, fitted_mixture, title):
111111 legend_fontsize = 14
112112 tick_fontsize = 14
113113
114- sns .histplot (x , color = "royalblue" , ax = ax , stat = "density" , alpha = 0.8 ,
115- binwidth = 0.5 , edgecolor = 'white' , linewidth = 1 )
114+ sns .histplot (x , color = "royalblue" , ax = ax , stat = "density" , alpha = 0.8 , binwidth = 0.5 , edgecolor = "white" , linewidth = 1 )
116115
117- ax .set_xlabel ("Значение x" , fontsize = label_fontsize , fontweight = ' bold' , labelpad = 10 )
118- ax .set_ylabel ("Плотность (density)" , fontsize = label_fontsize , fontweight = ' bold' , labelpad = 10 )
119- ax .set_title (title , fontsize = title_fontsize , fontweight = ' bold' , pad = 15 )
120- ax .grid (True , linestyle = '--' , alpha = 0.5 , linewidth = 1 )
116+ ax .set_xlabel ("Значение x" , fontsize = label_fontsize , fontweight = " bold" , labelpad = 10 )
117+ ax .set_ylabel ("Плотность (density)" , fontsize = label_fontsize , fontweight = " bold" , labelpad = 10 )
118+ ax .set_title (title , fontsize = title_fontsize , fontweight = " bold" , pad = 15 )
119+ ax .grid (True , linestyle = "--" , alpha = 0.5 , linewidth = 1 )
121120 ax .set_xlim (0 , 20 )
122121
123122 ax .set_xticks (np .arange (0 , 21 , 2 ))
124123 ax .set_yticks (np .linspace (0 , ax .get_yticks ().max (), len (ax .get_yticks ())))
125124
126- ax .tick_params (axis = 'both' , which = 'both' ,
127- labelsize = tick_fontsize ,
128- width = 3 , length = 8 ,
129- pad = 8 ,
130- colors = 'black' ,
131- grid_color = 'black' ,
132- grid_alpha = 0.5 )
125+ ax .tick_params (
126+ axis = "both" ,
127+ which = "both" ,
128+ labelsize = tick_fontsize ,
129+ width = 3 ,
130+ length = 8 ,
131+ pad = 8 ,
132+ colors = "black" ,
133+ grid_color = "black" ,
134+ grid_alpha = 0.5 ,
135+ )
133136
134137 for label in ax .get_xticklabels () + ax .get_yticklabels ():
135- label .set_fontweight (' bold' )
138+ label .set_fontweight (" bold" )
136139
137140 for spine in ax .spines .values ():
138141 spine .set_linewidth (3 )
139- spine .set_color (' black' )
142+ spine .set_color (" black" )
140143
141144 ax_ = ax .twinx ()
142- ax_ .set_ylabel ("p(x)" , fontsize = label_fontsize , fontweight = ' bold' , labelpad = 15 )
145+ ax_ .set_ylabel ("p(x)" , fontsize = label_fontsize , fontweight = " bold" , labelpad = 15 )
143146 ax_ .set_yscale ("log" )
144147
145148 y_ticks = [0.01 , 0.02 , 0.05 , 0.1 , 0.2 , 0.5 , 1.0 ]
146149 ax_ .set_yticks (y_ticks )
147- ax_ .set_yticklabels ([f"{ tick :.2f} " for tick in y_ticks ],
148- fontsize = tick_fontsize ,
149- fontweight = 'bold' ,
150- color = 'black' )
150+ ax_ .set_yticklabels ([f"{ tick :.2f} " for tick in y_ticks ], fontsize = tick_fontsize , fontweight = "bold" , color = "black" )
151151
152- ax_ .tick_params (axis = 'y' , which = 'both' ,
153- width = 3 , length = 8 ,
154- pad = 10 ,
155- colors = 'black' )
152+ ax_ .tick_params (axis = "y" , which = "both" , width = 3 , length = 8 , pad = 10 , colors = "black" )
156153
157154 ax_ .set_ylim (bottom = y_ticks [0 ], top = y_ticks [- 1 ])
158155
159156 for spine in ax_ .spines .values ():
160157 spine .set_linewidth (3 )
161- spine .set_color (' black' )
158+ spine .set_color (" black" )
162159
163160 X_plot = np .linspace (0.001 , 20 , 1000 )
164- ax_ .plot (X_plot , [true_mixture .pdf (xi ) for xi in X_plot ],
165- color = "darkgreen" , label = "Истинное распределение" ,
166- linewidth = 4 , linestyle = '-' , alpha = 0.9 )
167- ax_ .plot (X_plot , [fitted_mixture .pdf (xi ) for xi in X_plot ],
168- color = "crimson" , label = "Подобранное распределение" ,
169- linewidth = 4 , linestyle = '--' , alpha = 0.9 )
170-
171- legend = ax_ .legend (loc = 'upper right' ,
172- fontsize = legend_fontsize ,
173- framealpha = 1 ,
174- edgecolor = 'black' ,
175- facecolor = 'white' ,
176- frameon = True ,
177- borderpad = 1 )
161+ ax_ .plot (
162+ X_plot ,
163+ [true_mixture .pdf (xi ) for xi in X_plot ],
164+ color = "darkgreen" ,
165+ label = "Истинное распределение" ,
166+ linewidth = 4 ,
167+ linestyle = "-" ,
168+ alpha = 0.9 ,
169+ )
170+ ax_ .plot (
171+ X_plot ,
172+ [fitted_mixture .pdf (xi ) for xi in X_plot ],
173+ color = "crimson" ,
174+ label = "Подобранное распределение" ,
175+ linewidth = 4 ,
176+ linestyle = "--" ,
177+ alpha = 0.9 ,
178+ )
179+
180+ legend = ax_ .legend (
181+ loc = "upper right" ,
182+ fontsize = legend_fontsize ,
183+ framealpha = 1 ,
184+ edgecolor = "black" ,
185+ facecolor = "white" ,
186+ frameon = True ,
187+ borderpad = 1 ,
188+ )
178189 legend .get_frame ().set_linewidth (2 )
179190
180191 ax .minorticks_on ()
181192 ax_ .minorticks_on ()
182- ax .tick_params (axis = ' both' , which = ' minor' , width = 2 , length = 5 )
183- ax_ .tick_params (axis = ' both' , which = ' minor' , width = 2 , length = 5 )
193+ ax .tick_params (axis = " both" , which = " minor" , width = 2 , length = 5 )
194+ ax_ .tick_params (axis = " both" , which = " minor" , width = 2 , length = 5 )
184195
185196 for y in y_ticks :
186- ax_ .axhline (y = y , color = ' gray' , linestyle = ':' , alpha = 0.3 , linewidth = 1 )
197+ ax_ .axhline (y = y , color = " gray" , linestyle = ":" , alpha = 0.3 , linewidth = 1 )
187198
188199
189200def save_metrics_table (metrics_data : dict [str , dict [str , float ]], filename : str , title : str ):
@@ -244,10 +255,14 @@ def _initialize_methods(mixture: MixtureDistribution, eps) -> list[tuple]:
244255 raise ValueError (f"Unsupported model type: { model_type } " )
245256 n_clusters = len (models )
246257 return [
247- ("BayesEStep" ,None ,BayesEStep ()),
248- ("KMeans+ML" ,"kmeans" ,EnhancedClusteringEStep (models ,clusterizer = KMeans (n_clusters = n_clusters ))),
249- ("Agglo+ML" ,"agglo" ,EnhancedClusteringEStep (models ,clusterizer = AgglomerativeClustering (n_clusters = n_clusters ))),
250- ("DBSCAN+ML" ,"dbscan" ,EnhancedClusteringEStep (models ,eps = eps ,clusterizer = DBSCAN ())),
258+ ("BayesEStep" , None , BayesEStep ()),
259+ ("KMeans+ML" , "kmeans" , EnhancedClusteringEStep (models , clusterizer = KMeans (n_clusters = n_clusters ))),
260+ (
261+ "Agglo+ML" ,
262+ "agglo" ,
263+ EnhancedClusteringEStep (models , clusterizer = AgglomerativeClustering (n_clusters = n_clusters )),
264+ ),
265+ ("DBSCAN+ML" , "dbscan" , EnhancedClusteringEStep (models , eps = eps , clusterizer = DBSCAN ())),
251266 ]
252267
253268
@@ -292,9 +307,14 @@ def _calculate_summary_metrics(all_results: dict) -> dict:
292307 return summary_metrics
293308
294309
295- def _save_comparison_plots (methods : list , mixture : MixtureDistribution ,
296- problem : Problem , summary_metrics : dict ,
297- group_name : str , sample_size : int ):
310+ def _save_comparison_plots (
311+ methods : list ,
312+ mixture : MixtureDistribution ,
313+ problem : Problem ,
314+ summary_metrics : dict ,
315+ group_name : str ,
316+ sample_size : int ,
317+ ):
298318 """Save all comparison plots with metrics under titles"""
299319 fig , axes = plt .subplots (2 , 2 , figsize = (18 , 14 ))
300320 # fig.suptitle(f"Comparison of methods for {group_name} group (n={sample_size})", fontsize=16)
@@ -330,8 +350,7 @@ def _save_comparison_plots(methods: list, mixture: MixtureDistribution,
330350 _save_pair_plots (methods , mixture , problem , group_name )
331351
332352
333- def _save_pair_plots (methods : list , mixture : MixtureDistribution ,
334- problem : Problem , group_name : str ):
353+ def _save_pair_plots (methods : list , mixture : MixtureDistribution , problem : Problem , group_name : str ):
335354 """Save pair comparison plots with metrics"""
336355 # Bayes vs KMeans
337356 fig , axes = plt .subplots (1 , 2 , figsize = (18 , 8 ))
@@ -345,9 +364,7 @@ def _save_pair_plots(methods: list, mixture: MixtureDistribution,
345364 em = EM (StepCountBreakpointer (max_step = 128 ), FiniteChecker (), method = method )
346365 result = em .solve (problem )
347366
348- title = (
349- f"{ name } "
350- )
367+ title = f"{ name } "
351368 plot_distributions (ax , problem .samples , mixture , result .result , title )
352369
353370 plt .tight_layout ()
@@ -365,9 +382,7 @@ def _save_pair_plots(methods: list, mixture: MixtureDistribution,
365382 method = Method (e_step , m_step )
366383 em = EM (StepCountBreakpointer (max_step = 128 ), FiniteChecker (), method = method )
367384 result = em .solve (problem )
368- title = (
369- f"{ name } "
370- )
385+ title = f"{ name } "
371386 plot_distributions (ax , problem .samples , mixture , result .result , title )
372387
373388 plt .tight_layout ()
@@ -376,7 +391,7 @@ def _save_pair_plots(methods: list, mixture: MixtureDistribution,
376391
377392
378393def run_experiment_group (
379- mixture : MixtureDistribution , sample_size : int , n_experiments : int = 5 , group_name : str = "default"
394+ mixture : MixtureDistribution , sample_size : int , n_experiments : int = 5 , group_name : str = "default"
380395) -> dict [str , dict [str , float ]]:
381396 """Run multiple experiments for a given mixture model"""
382397 all_results = {method : [] for method in ["BayesEStep" , "KMeans+ML" , "Agglo+ML" , "DBSCAN+ML" ]}
@@ -403,8 +418,9 @@ def run_experiment_group(
403418 x = mixture .generate (sample_size )
404419 eps = EnhancedClusteringEStep .auto_eps (x )
405420 problem = Problem (x , mixture )
406- _save_comparison_plots (_initialize_methods (mixture , eps ), mixture ,
407- problem , summary_metrics , group_name , sample_size )
421+ _save_comparison_plots (
422+ _initialize_methods (mixture , eps ), mixture , problem , summary_metrics , group_name , sample_size
423+ )
408424
409425 return summary_metrics
410426
0 commit comments