Skip to content

Commit 6f2fee4

Browse files
authored
fix: minor changes for three last commits
1 parent d8b9c5f commit 6f2fee4

File tree

10 files changed

+145
-164
lines changed

10 files changed

+145
-164
lines changed

examples/readme_example/example_ml.py

Lines changed: 76 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -111,79 +111,90 @@ def plot_distributions(ax, x, true_mixture, fitted_mixture, title):
111111
legend_fontsize = 14
112112
tick_fontsize = 14
113113

114-
sns.histplot(x, color="royalblue", ax=ax, stat="density", alpha=0.8,
115-
binwidth=0.5, edgecolor='white', linewidth=1)
114+
sns.histplot(x, color="royalblue", ax=ax, stat="density", alpha=0.8, binwidth=0.5, edgecolor="white", linewidth=1)
116115

117-
ax.set_xlabel("Значение x", fontsize=label_fontsize, fontweight='bold', labelpad=10)
118-
ax.set_ylabel("Плотность (density)", fontsize=label_fontsize, fontweight='bold', labelpad=10)
119-
ax.set_title(title, fontsize=title_fontsize, fontweight='bold', pad=15)
120-
ax.grid(True, linestyle='--', alpha=0.5, linewidth=1)
116+
ax.set_xlabel("Значение x", fontsize=label_fontsize, fontweight="bold", labelpad=10)
117+
ax.set_ylabel("Плотность (density)", fontsize=label_fontsize, fontweight="bold", labelpad=10)
118+
ax.set_title(title, fontsize=title_fontsize, fontweight="bold", pad=15)
119+
ax.grid(True, linestyle="--", alpha=0.5, linewidth=1)
121120
ax.set_xlim(0, 20)
122121

123122
ax.set_xticks(np.arange(0, 21, 2))
124123
ax.set_yticks(np.linspace(0, ax.get_yticks().max(), len(ax.get_yticks())))
125124

126-
ax.tick_params(axis='both', which='both',
127-
labelsize=tick_fontsize,
128-
width=3, length=8,
129-
pad=8,
130-
colors='black',
131-
grid_color='black',
132-
grid_alpha=0.5)
125+
ax.tick_params(
126+
axis="both",
127+
which="both",
128+
labelsize=tick_fontsize,
129+
width=3,
130+
length=8,
131+
pad=8,
132+
colors="black",
133+
grid_color="black",
134+
grid_alpha=0.5,
135+
)
133136

134137
for label in ax.get_xticklabels() + ax.get_yticklabels():
135-
label.set_fontweight('bold')
138+
label.set_fontweight("bold")
136139

137140
for spine in ax.spines.values():
138141
spine.set_linewidth(3)
139-
spine.set_color('black')
142+
spine.set_color("black")
140143

141144
ax_ = ax.twinx()
142-
ax_.set_ylabel("p(x)", fontsize=label_fontsize, fontweight='bold', labelpad=15)
145+
ax_.set_ylabel("p(x)", fontsize=label_fontsize, fontweight="bold", labelpad=15)
143146
ax_.set_yscale("log")
144147

145148
y_ticks = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0]
146149
ax_.set_yticks(y_ticks)
147-
ax_.set_yticklabels([f"{tick:.2f}" for tick in y_ticks],
148-
fontsize=tick_fontsize,
149-
fontweight='bold',
150-
color='black')
150+
ax_.set_yticklabels([f"{tick:.2f}" for tick in y_ticks], fontsize=tick_fontsize, fontweight="bold", color="black")
151151

152-
ax_.tick_params(axis='y', which='both',
153-
width=3, length=8,
154-
pad=10,
155-
colors='black')
152+
ax_.tick_params(axis="y", which="both", width=3, length=8, pad=10, colors="black")
156153

157154
ax_.set_ylim(bottom=y_ticks[0], top=y_ticks[-1])
158155

159156
for spine in ax_.spines.values():
160157
spine.set_linewidth(3)
161-
spine.set_color('black')
158+
spine.set_color("black")
162159

163160
X_plot = np.linspace(0.001, 20, 1000)
164-
ax_.plot(X_plot, [true_mixture.pdf(xi) for xi in X_plot],
165-
color="darkgreen", label="Истинное распределение",
166-
linewidth=4, linestyle='-', alpha=0.9)
167-
ax_.plot(X_plot, [fitted_mixture.pdf(xi) for xi in X_plot],
168-
color="crimson", label="Подобранное распределение",
169-
linewidth=4, linestyle='--', alpha=0.9)
170-
171-
legend = ax_.legend(loc='upper right',
172-
fontsize=legend_fontsize,
173-
framealpha=1,
174-
edgecolor='black',
175-
facecolor='white',
176-
frameon=True,
177-
borderpad=1)
161+
ax_.plot(
162+
X_plot,
163+
[true_mixture.pdf(xi) for xi in X_plot],
164+
color="darkgreen",
165+
label="Истинное распределение",
166+
linewidth=4,
167+
linestyle="-",
168+
alpha=0.9,
169+
)
170+
ax_.plot(
171+
X_plot,
172+
[fitted_mixture.pdf(xi) for xi in X_plot],
173+
color="crimson",
174+
label="Подобранное распределение",
175+
linewidth=4,
176+
linestyle="--",
177+
alpha=0.9,
178+
)
179+
180+
legend = ax_.legend(
181+
loc="upper right",
182+
fontsize=legend_fontsize,
183+
framealpha=1,
184+
edgecolor="black",
185+
facecolor="white",
186+
frameon=True,
187+
borderpad=1,
188+
)
178189
legend.get_frame().set_linewidth(2)
179190

180191
ax.minorticks_on()
181192
ax_.minorticks_on()
182-
ax.tick_params(axis='both', which='minor', width=2, length=5)
183-
ax_.tick_params(axis='both', which='minor', width=2, length=5)
193+
ax.tick_params(axis="both", which="minor", width=2, length=5)
194+
ax_.tick_params(axis="both", which="minor", width=2, length=5)
184195

185196
for y in y_ticks:
186-
ax_.axhline(y=y, color='gray', linestyle=':', alpha=0.3, linewidth=1)
197+
ax_.axhline(y=y, color="gray", linestyle=":", alpha=0.3, linewidth=1)
187198

188199

189200
def save_metrics_table(metrics_data: dict[str, dict[str, float]], filename: str, title: str):
@@ -244,10 +255,14 @@ def _initialize_methods(mixture: MixtureDistribution, eps) -> list[tuple]:
244255
raise ValueError(f"Unsupported model type: {model_type}")
245256
n_clusters = len(models)
246257
return [
247-
("BayesEStep",None,BayesEStep()),
248-
("KMeans+ML","kmeans",EnhancedClusteringEStep(models,clusterizer=KMeans(n_clusters=n_clusters))),
249-
("Agglo+ML","agglo",EnhancedClusteringEStep(models,clusterizer=AgglomerativeClustering(n_clusters=n_clusters))),
250-
("DBSCAN+ML","dbscan",EnhancedClusteringEStep(models,eps=eps,clusterizer=DBSCAN())),
258+
("BayesEStep", None, BayesEStep()),
259+
("KMeans+ML", "kmeans", EnhancedClusteringEStep(models, clusterizer=KMeans(n_clusters=n_clusters))),
260+
(
261+
"Agglo+ML",
262+
"agglo",
263+
EnhancedClusteringEStep(models, clusterizer=AgglomerativeClustering(n_clusters=n_clusters)),
264+
),
265+
("DBSCAN+ML", "dbscan", EnhancedClusteringEStep(models, eps=eps, clusterizer=DBSCAN())),
251266
]
252267

253268

@@ -292,9 +307,14 @@ def _calculate_summary_metrics(all_results: dict) -> dict:
292307
return summary_metrics
293308

294309

295-
def _save_comparison_plots(methods: list, mixture: MixtureDistribution,
296-
problem: Problem, summary_metrics: dict,
297-
group_name: str, sample_size: int):
310+
def _save_comparison_plots(
311+
methods: list,
312+
mixture: MixtureDistribution,
313+
problem: Problem,
314+
summary_metrics: dict,
315+
group_name: str,
316+
sample_size: int,
317+
):
298318
"""Save all comparison plots with metrics under titles"""
299319
fig, axes = plt.subplots(2, 2, figsize=(18, 14))
300320
# fig.suptitle(f"Comparison of methods for {group_name} group (n={sample_size})", fontsize=16)
@@ -330,8 +350,7 @@ def _save_comparison_plots(methods: list, mixture: MixtureDistribution,
330350
_save_pair_plots(methods, mixture, problem, group_name)
331351

332352

333-
def _save_pair_plots(methods: list, mixture: MixtureDistribution,
334-
problem: Problem, group_name: str):
353+
def _save_pair_plots(methods: list, mixture: MixtureDistribution, problem: Problem, group_name: str):
335354
"""Save pair comparison plots with metrics"""
336355
# Bayes vs KMeans
337356
fig, axes = plt.subplots(1, 2, figsize=(18, 8))
@@ -345,9 +364,7 @@ def _save_pair_plots(methods: list, mixture: MixtureDistribution,
345364
em = EM(StepCountBreakpointer(max_step=128), FiniteChecker(), method=method)
346365
result = em.solve(problem)
347366

348-
title = (
349-
f"{name}"
350-
)
367+
title = f"{name}"
351368
plot_distributions(ax, problem.samples, mixture, result.result, title)
352369

353370
plt.tight_layout()
@@ -365,9 +382,7 @@ def _save_pair_plots(methods: list, mixture: MixtureDistribution,
365382
method = Method(e_step, m_step)
366383
em = EM(StepCountBreakpointer(max_step=128), FiniteChecker(), method=method)
367384
result = em.solve(problem)
368-
title = (
369-
f"{name}"
370-
)
385+
title = f"{name}"
371386
plot_distributions(ax, problem.samples, mixture, result.result, title)
372387

373388
plt.tight_layout()
@@ -376,7 +391,7 @@ def _save_pair_plots(methods: list, mixture: MixtureDistribution,
376391

377392

378393
def run_experiment_group(
379-
mixture: MixtureDistribution, sample_size: int, n_experiments: int = 5, group_name: str = "default"
394+
mixture: MixtureDistribution, sample_size: int, n_experiments: int = 5, group_name: str = "default"
380395
) -> dict[str, dict[str, float]]:
381396
"""Run multiple experiments for a given mixture model"""
382397
all_results = {method: [] for method in ["BayesEStep", "KMeans+ML", "Agglo+ML", "DBSCAN+ML"]}
@@ -403,8 +418,9 @@ def run_experiment_group(
403418
x = mixture.generate(sample_size)
404419
eps = EnhancedClusteringEStep.auto_eps(x)
405420
problem = Problem(x, mixture)
406-
_save_comparison_plots(_initialize_methods(mixture, eps), mixture,
407-
problem, summary_metrics, group_name, sample_size)
421+
_save_comparison_plots(
422+
_initialize_methods(mixture, eps), mixture, problem, summary_metrics, group_name, sample_size
423+
)
408424

409425
return summary_metrics
410426

experimental_env/mixture_generators/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from random import uniform
44

55
from mpest import Distribution
6-
from mpest.models import AModel, ExponentialModel, GaussianModel
6+
from mpest.models import AModel, Beta, Cauchy, ExponentialModel, GaussianModel, Pareto, Uniform
77

88

99
def generate_standart_params(models: list[type[AModel]]) -> list[Distribution]:
@@ -14,8 +14,12 @@ def generate_standart_params(models: list[type[AModel]]) -> list[Distribution]:
1414
for m in models:
1515
if m == ExponentialModel:
1616
params = [1.0]
17-
elif m == GaussianModel:
17+
elif m in (GaussianModel, Uniform, Cauchy):
1818
params = [0.0, 1.0]
19+
elif m == Beta:
20+
params = [1.0, 1.0]
21+
elif m == Pareto:
22+
params = [1.0, 2.0]
1923
else: # Weibull
2024
params = [1.0, 1.0]
2125

@@ -34,6 +38,14 @@ def generate_uniform_params(models: list[type[AModel]]) -> list[Distribution]:
3438
params = [uniform(0.1, 5.0)]
3539
elif m == GaussianModel:
3640
params = [uniform(-5.0, 5.0), uniform(0.1, 5.0)]
41+
elif m == Uniform:
42+
params = list(sorted([uniform(-5.0, 5.0), uniform(-5.0, 5.0)]))
43+
elif m == Cauchy:
44+
params = [uniform(-5.0, 5.0), uniform(0.1, 5.0)]
45+
elif m == Beta:
46+
params = [uniform(0.1, 5.0), uniform(0.1, 5.0)]
47+
elif m == Pareto:
48+
params = [uniform(0.1, 5.0), uniform(1.0, 5.0)]
3749
else: # Weibull
3850
params = [uniform(0.1, 5.0), uniform(0.1, 5.0)]
3951

mm.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

mpest/em/methods/l_moments_method.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
EResult = tuple[Problem, np.ndarray] | ResultWithError[MixtureDistribution]
1717

18+
1819
class LMomentsMStep(AMaximization[EResult]):
1920
"""
2021
Class which calculate new params using matrix with indicator from E step.
@@ -55,7 +56,6 @@ def calculate_mr_j(self, r: int, j: int, samples: Samples, indicators: np.ndarra
5556
mr_j += p_rk * b_k
5657
return mr_j
5758

58-
5959
def step(self, e_result: EResult) -> Result:
6060
"""
6161
A function that performs E step

0 commit comments

Comments
 (0)