Skip to content

Commit 0607c2a

Browse files
committed
feat(estimators, optimizers): switch to generics
1 parent 8b80550 commit 0607c2a

29 files changed

+874
-423
lines changed

rework_pysatl_mpest/core/mixture.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,30 @@ def generate(self, size: int) -> NDArray[DType]:
347347
np.random.shuffle(samples)
348348
return samples
349349

350+
def astype(self, new_dtype: type[DType]) -> "MixtureModel[DType]":
351+
"""Creates a copy of the MixtureModel with a new data type.
352+
353+
If the specified `new_dtype` is the same as the instance's current `dtype`,
354+
this method returns the original instance instead.
355+
356+
Parameters
357+
----------
358+
new_dtype : type[DType]
359+
The target NumPy data type for the new mixture instance.
360+
361+
Returns
362+
-------
363+
MixtureModel[DType]
364+
A new MixtureModel instance with all components and weights converted to the
365+
specified `new_dtype`, or the original instance if the `dtype` is
366+
unchanged.
367+
"""
368+
if self.dtype is new_dtype:
369+
return self
370+
371+
new_mixture = MixtureModel(components=self.components, weights=self.weights.copy(), dtype=new_dtype)
372+
return new_mixture
373+
350374
def __getitem__(self, key: int) -> "ContinuousDistribution[DType]":
351375
"""Retrieves components by index.
352376

rework_pysatl_mpest/estimators/base_estimator.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
"""A module that provides an abstract class for implementing custom estimators."""
22

3-
__author__ = "Danil Totmyanin"
3+
__author__ = "Danil Totmyanin, Aleksandra Ri"
44
__copyright__ = "Copyright (c) 2025 PySATL project"
55
__license__ = "SPDX-License-Identifier: MIT"
66

77
from abc import ABC, abstractmethod
8+
from typing import Generic
89

910
from numpy.typing import ArrayLike
1011

1112
from ..core import MixtureModel
13+
from ..typings import DType
1214

1315

14-
class BaseEstimator(ABC):
16+
class BaseEstimator(ABC, Generic[DType]):
1517
"""Abstract class for a mixture model parameter estimator.
1618
1719
This class defines the interface for all estimator algorithms. Estimators are responsible for
@@ -35,7 +37,7 @@ class BaseEstimator(ABC):
3537
"""
3638

3739
@abstractmethod
38-
def fit(self, X: ArrayLike, mixture: MixtureModel) -> MixtureModel:
40+
def fit(self, X: ArrayLike, mixture: MixtureModel[DType]) -> MixtureModel[DType]:
3941
"""Fits the mixture model to the provided data.
4042
4143
This method estimates the parameters of the model's components and their
@@ -45,11 +47,11 @@ def fit(self, X: ArrayLike, mixture: MixtureModel) -> MixtureModel:
4547
----------
4648
X : ArrayLike
4749
The input data sample for fitting the model.
48-
mixture : MixtureModel
50+
mixture : MixtureModel[DType]
4951
The initial mixture model to be fitted.
5052
5153
Returns
5254
-------
53-
MixtureModel
55+
MixtureModel[DType]
5456
The mixture model with estimated parameters.
5557
"""

rework_pysatl_mpest/estimators/ecm.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
to fit the parameters of a mixture model to data.
77
"""
88

9-
__author__ = "Danil Totmyanin"
9+
__author__ = "Danil Totmyanin, Aleksandra Ri"
1010
__copyright__ = "Copyright (c) 2025 PySATL project"
1111
__license__ = "SPDX-License-Identifier: MIT"
1212

@@ -17,6 +17,7 @@
1717

1818
from ..core import MixtureModel
1919
from ..optimizers import Optimizer
20+
from ..typings import DType
2021
from .base_estimator import BaseEstimator
2122
from .iterative import (
2223
Breakpointer,
@@ -30,7 +31,7 @@
3031
from .iterative._logger import IterationsHistory
3132

3233

33-
class ECM(BaseEstimator):
34+
class ECM(BaseEstimator[DType]):
3435
"""An estimator that implements the Expectation-Conditional Maximization (ECM) algorithm.
3536
3637
This class encapsulates the logic for the ECM algorithm, a variant of the
@@ -87,10 +88,10 @@ def __init__(self, breakpointers: Sequence[Breakpointer], pruners: Sequence[Prun
8788
self.breakpointers = list(breakpointers)
8889
self.pruners = list(pruners)
8990
self.optimizer = optimizer
90-
self._logger: IterationsHistory | None = None
91+
self._logger: IterationsHistory[DType] | None = None
9192

9293
@property
93-
def logger(self) -> IterationsHistory:
94+
def logger(self) -> IterationsHistory[DType]:
9495
"""An object that collects information about each iteration.
9596
9697
Raises
@@ -103,7 +104,7 @@ def logger(self) -> IterationsHistory:
103104
raise AttributeError("Logger is not available. Call the 'fit' method first.")
104105
return self._logger
105106

106-
def fit(self, X: ArrayLike, mixture: MixtureModel, once_in_iterations: int = 1) -> MixtureModel:
107+
def fit(self, X: ArrayLike, mixture: MixtureModel[DType], once_in_iterations: int = 1) -> MixtureModel[DType]:
107108
"""Fits the mixture model to the data using the ECM algorithm.
108109
109110
This method sets up and runs an iterative pipeline to estimate the
@@ -115,15 +116,15 @@ def fit(self, X: ArrayLike, mixture: MixtureModel, once_in_iterations: int = 1)
115116
----------
116117
X : ArrayLike
117118
The input dataset for fitting the model.
118-
mixture : MixtureModel
119+
mixture : MixtureModel[DType]
119120
The initial mixture model to be fitted.
120121
once_in_iterations : int, optional
121122
The logging frequency. A value of `n` means logging occurs every
122123
`n` iterations. Defaults to 1.
123124
124125
Returns
125126
-------
126-
MixtureModel
127+
MixtureModel[DType]
127128
The mixture model with the estimated parameters.
128129
"""
129130

@@ -132,7 +133,7 @@ def fit(self, X: ArrayLike, mixture: MixtureModel, once_in_iterations: int = 1)
132133
block = OptimizationBlock(i, comp.params_to_optimize, MaximizationStrategy.QFUNCTION)
133134
blocks.append(block)
134135

135-
pipeline = Pipeline(
136+
pipeline: Pipeline[DType] = Pipeline(
136137
[ExpectationStep(), MaximizationStep(blocks, self.optimizer)],
137138
self.breakpointers,
138139
self.pruners,

rework_pysatl_mpest/estimators/iterative/_logger.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,22 @@
66
and errors across iterations.
77
"""
88

9-
__author__ = "Maksim Pastukhov"
9+
__author__ = "Maksim Pastukhov, Aleksandra Ri"
1010
__copyright__ = "Copyright (c) 2025 PySATL project"
1111
__license__ = "SPDX-License-Identifier: MIT"
1212

1313
from dataclasses import dataclass
14-
from typing import Optional
14+
from typing import Generic, Optional
1515

16-
from numpy import float64
1716
from numpy.typing import NDArray
1817

1918
from ...core import MixtureModel
19+
from ...typings import DType
2020
from .pruner import Pruner
2121

2222

2323
@dataclass
24-
class IterationRecord:
24+
class IterationRecord(Generic[DType]):
2525
"""Data class representing a single pipeline iteration snapshot.
2626
2727
This class captures the complete state of a pipeline iteration after
@@ -32,15 +32,15 @@ class IterationRecord:
3232
----------
3333
iteration : int
3434
The iteration number (0-based index).
35-
mixture : MixtureModel
35+
mixture : MixtureModel[DType]
3636
The state of the mixture model after pruning in this iteration.
37-
X : NDArray[float64]
37+
X : NDArray[DType]
3838
The input data sample being processed (conventionally named `X`).
39-
H : Optional[NDArray[float64]]
39+
H : Optional[NDArray[DType]]
4040
The responsibility matrix (posterior probabilities) if available,
4141
where `H[i, j]` represents the probability that data point `i`
4242
belongs to component `j`. May be `None` if not computed.
43-
pruners_used : Optional[list[Pruner]]
43+
pruners_used : Optional[list[Pruner[DType]]]
4444
List of pruner instances that were applied during this iteration.
4545
`None` or empty if no pruning occurred.
4646
error : Optional[Exception]
@@ -49,14 +49,14 @@ class IterationRecord:
4949
"""
5050

5151
iteration: int
52-
mixture: MixtureModel
53-
X: NDArray[float64]
54-
H: Optional[NDArray[float64]]
55-
pruners_used: Optional[list[Pruner]]
52+
mixture: MixtureModel[DType]
53+
X: NDArray[DType]
54+
H: Optional[NDArray[DType]]
55+
pruners_used: Optional[list[Pruner[DType]]]
5656
error: Optional[Exception]
5757

5858

59-
class IterationsHistory:
59+
class IterationsHistory(Generic[DType]):
6060
"""A container for storing and accessing pipeline iteration history.
6161
6262
`IterationsHistory` collects and stores snapshots of each pipeline iteration
@@ -98,7 +98,7 @@ class IterationsHistory:
9898
_counter : int
9999
Internal counter tracking the total number of `log()` calls
100100
(i.e., total iterations processed, not just recorded ones).
101-
_logs : list[IterationRecord]
101+
_logs : list[IterationRecord[DType]]
102102
List of stored iteration records. Only iterations matching the
103103
recording frequency are appended.
104104
@@ -123,11 +123,11 @@ def __init__(self, once_in_iterations: int = 1) -> None:
123123
if once_in_iterations < 1:
124124
raise ValueError("once_in_iterations must be a positive integer")
125125

126-
self._logs: list[IterationRecord] = []
126+
self._logs: list[IterationRecord[DType]] = []
127127
self._counter: int = 0
128128
self.once_in_iterations = once_in_iterations
129129

130-
def log(self, record: IterationRecord) -> None:
130+
def log(self, record: IterationRecord[DType]) -> None:
131131
"""Store an iteration record based on the configured frequency.
132132
133133
The record is stored only if the current internal counter is divisible
@@ -136,7 +136,7 @@ def log(self, record: IterationRecord) -> None:
136136
137137
Parameters
138138
----------
139-
record : IterationRecord
139+
record : IterationRecord[DType]
140140
The iteration snapshot to potentially store. The `record.iteration`
141141
should ideally match the logger's internal state, though this is
142142
not enforced.
@@ -168,7 +168,7 @@ def __len__(self) -> int:
168168
"""
169169
return len(self._logs)
170170

171-
def __getitem__(self, index: int) -> IterationRecord:
171+
def __getitem__(self, index: int) -> IterationRecord[DType]:
172172
"""Access a stored iteration record by index.
173173
174174
Supports both positive (0-based) and negative indexing (e.g., `-1` for last).
@@ -181,7 +181,7 @@ def __getitem__(self, index: int) -> IterationRecord:
181181
182182
Returns
183183
-------
184-
IterationRecord
184+
IterationRecord[DType]
185185
The recorded state of the specified iteration.
186186
187187
Raises

rework_pysatl_mpest/estimators/iterative/_strategies/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99

1010
from ....distributions import ContinuousDistribution
1111
from ....optimizers import Optimizer
12+
from ....typings import DType
1213
from ..pipeline_state import PipelineState
1314
from ..steps import OptimizationBlock
1415
from .q_function import q_function_strategy as _q_function_strategy
1516

1617
q_function_strategy: Callable[
17-
[ContinuousDistribution, PipelineState, OptimizationBlock, Optimizer], tuple[int, dict[str, float]]
18+
[ContinuousDistribution, PipelineState, OptimizationBlock, Optimizer], tuple[int, dict[str, DType]]
1819
] = _q_function_strategy
1920

2021

0 commit comments

Comments
 (0)