Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 75 additions & 50 deletions rework_pysatl_mpest/core/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,25 @@

from collections.abc import Iterator, Sequence
from copy import copy
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Generic, Optional

import numpy as np
from numpy import float64
from numpy.typing import ArrayLike, NDArray
from scipy.special import logsumexp, softmax

from ..typings import DType

if TYPE_CHECKING:
from ..distributions import ContinuousDistribution


class MixtureModel:
class MixtureModel(Generic[DType]):
"""Represents a finite mixture of continuous probability distributions.

This class encapsulates a collection of distribution components and their
corresponding weights.
corresponding weights. All components within the mixture are automatically
converted to the specified `dtype` of the MixtureModel, ensuring
computational consistency.

Instances of this class can be compared for equality (``==``) and
inequality (``!=``). Two models are considered equal if they have the
Expand All @@ -39,17 +42,21 @@ class MixtureModel:
An array of initial weights for the components. The weights must be
positive and sum to 1. If None, components are assigned equal
weights. Defaults to None.
dtype : type[DType], optional
The numpy data type used for internal calculations and
output arrays (e.g., `np.float32` or `np.float64`).
Defaults to `np.float64`.

Attributes
----------
components : tuple[ContinuousDistribution]
components : tuple[ContinuousDistribution[DType], ...]
A tuple of the distribution objects that form the mixture.
n_components : int
The number of components in the mixture.
weights : NDArray[np.float64]
weights : NDArray[DType]
A NumPy array of the normalized weights for each component. The sum
of weights is always 1.
log_weights : NDArray[np.float64]
log_weights : NDArray[DType]
A NumPy array of the natural logarithm of the component weights.

Raises
Expand All @@ -72,31 +79,40 @@ class MixtureModel:
generate
"""

def __init__(self, components: Sequence["ContinuousDistribution"], weights: Optional[ArrayLike] = None):
_dtype: type[DType]

def __init__(
self,
components: Sequence["ContinuousDistribution"],
weights: Optional[ArrayLike] = None,
dtype: type[DType] = np.float64, # type: ignore[assignment]
):
n_components = len(components)
if n_components == 0:
raise ValueError("List of components cannot be an empty")

self._dtype = dtype

if weights is None:
weights = np.full(n_components, 1.0 / n_components)
weights = np.full(n_components, 1.0 / n_components, dtype=self.dtype)
else:
weights = np.asarray(weights, dtype=float64)
weights = np.asarray(weights, dtype=self.dtype)
self._validate_weights(n_components, weights)

self._components = list(components)
self._log_weights = np.log(weights + 1e-30)
self._cached_weights: Optional[NDArray[float64]] = None
self._components = [comp.astype(self.dtype) for comp in components]
self._log_weights = np.log(weights + np.finfo(self.dtype).tiny)
self._cached_weights: Optional[NDArray[DType]] = None

self._sorted_pairs_cache: Optional[list[tuple[ContinuousDistribution, float]]] = None
self._sorted_pairs_cache: Optional[list[tuple[ContinuousDistribution[DType], DType]]] = None

def _validate_weights(self, n_components: int, weights: NDArray[float64]):
def _validate_weights(self, n_components: int, weights: NDArray[DType]):
"""Validates the component weights.

Parameters
----------
n_components : int
The expected number of components.
weights : NDArray[np.float64]
weights : NDArray[DType]
The array of weights to validate.

Raises
Expand All @@ -112,9 +128,15 @@ def _validate_weights(self, n_components: int, weights: NDArray[float64]):
if np.any(weights < 0):
raise ValueError("Weights must be positive.")

if not np.isclose(np.sum(weights), 1.0):
if not np.isclose(np.sum(weights), self.dtype(1.0)):
raise ValueError(f"Sum of the weights must be equal 1, but it equal {np.sum(weights)}.")

@property
def dtype(self) -> type[DType]:
"""type[DType]: The numpy data type of the mixture's outputs."""

return self._dtype

@property
def n_components(self):
"""int: The number of components in the mixture model."""
Expand All @@ -123,13 +145,13 @@ def n_components(self):

@property
def components(self):
"""tuple[ContinuousDistribution, ...]: The components of the mixture."""
"""tuple[ContinuousDistribution[DType], ...]: The components of the mixture."""

return tuple(self._components)

@property
def weights(self) -> NDArray[float64]:
"""NDArray[np.float64]: The normalized weights of the components.
def weights(self) -> NDArray[DType]:
"""NDArray[DType]: The normalized weights of the components.

The weights are computed from the log-weights using the softmax
function and cached for efficiency.
Expand All @@ -141,8 +163,8 @@ def weights(self) -> NDArray[float64]:
return self._cached_weights # type: ignore

@property
def log_weights(self) -> NDArray[float64]:
"""NDArray[np.float64]: The logarithm of the component weights."""
def log_weights(self) -> NDArray[DType]:
"""NDArray[DType]: The logarithm of the component weights."""

return self._log_weights

Expand All @@ -162,11 +184,11 @@ def log_weights(self, new_log_weights: ArrayLike):
number of components.
"""

new_log_weights = np.asarray(new_log_weights, dtype=float64)
new_log_weights = np.asarray(new_log_weights, dtype=self.dtype)

if len(new_log_weights) != self.n_components:
raise ValueError("The length of the new logit vector does not match the number of components.")
self._log_weights = np.asarray(new_log_weights, dtype=float)
self._log_weights = new_log_weights
self._cached_weights = None
self._sorted_pairs_cache = None

Expand All @@ -192,11 +214,13 @@ def add_component(self, component: "ContinuousDistribution", weight: float):
if not (0 < weight < 1):
raise ValueError("The weight of the new component must be in the range (0, 1).")

self._log_weights += np.log(1 - weight)
new_log_weight = np.log(weight)
d_weight = self.dtype(weight)
self._log_weights += np.log(self.dtype(1.0) - d_weight)
new_log_weight = np.log(d_weight)
self._log_weights = np.append(self._log_weights, new_log_weight)

self._components.append(component)
new_component = component.astype(self.dtype)
self._components.append(new_component)
self._cached_weights = None
self._sorted_pairs_cache = None

Expand Down Expand Up @@ -231,7 +255,7 @@ def remove_component(self, component_idx: int):
self._cached_weights = None
self._sorted_pairs_cache = None

def pdf(self, X: ArrayLike) -> NDArray[float64]:
def pdf(self, X: ArrayLike) -> NDArray[DType]:
"""Probability Density Function of the mixture.

The PDF is computed as the weighted sum of the PDFs of its
Expand All @@ -244,15 +268,15 @@ def pdf(self, X: ArrayLike) -> NDArray[float64]:

Returns
-------
NDArray[np.float64]
NDArray[DType]
The PDF values corresponding to each point in :attr:`X`.
"""

X = np.asarray(X, dtype=float64)
X = np.asarray(X, dtype=self.dtype)
component_pdfs = np.array([comp.pdf(X) for comp in self.components])
return np.asarray(np.dot(self.weights, component_pdfs))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rewrite this method for better numerical stability:

Suggested change
return np.asarray(np.dot(self.weights, component_pdfs))
return np.exp(self.lpdf(X))

After that change, the corresponding test must be updated to match, like this:

[image]


def lpdf(self, X: ArrayLike) -> NDArray[float64]:
def lpdf(self, X: ArrayLike) -> NDArray[DType]:
"""Logarithms of the Probability Density Function.

Parameters
Expand All @@ -262,17 +286,17 @@ def lpdf(self, X: ArrayLike) -> NDArray[float64]:

Returns
-------
NDArray[np.float64]
NDArray[DType]
The log-PDF values corresponding to each point in :attr:`X`.
"""

X = np.atleast_1d(X)
X = np.atleast_1d(X).astype(self.dtype)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
X = np.atleast_1d(X).astype(self.dtype)
X = np.asarray(X, dtype=self.dtype)

component_lpdfs = np.array([comp.lpdf(X) for comp in self.components])
log_weights = self.log_weights
log_terms = log_weights[:, np.newaxis] + component_lpdfs
return logsumexp(log_terms, axis=0) # type: ignore

def loglikelihood(self, X: ArrayLike) -> float:
def loglikelihood(self, X: ArrayLike) -> DType:
"""Log-likelihood of the complete data :attr:`X`.

The log-likelihood is the sum of the log-PDF values for all data
Expand All @@ -285,14 +309,14 @@ def loglikelihood(self, X: ArrayLike) -> float:

Returns
-------
float
DType
The total log-likelihood value.
"""

X = np.asarray(X, dtype=float64)
X = np.asarray(X, dtype=self.dtype)
return np.sum(self.lpdf(X))

def generate(self, size: int) -> NDArray[float64]:
def generate(self, size: int) -> NDArray[DType]:
"""Generates random samples from the mixture model.

First, a component is chosen based on the mixture weights. Then, a
Expand All @@ -306,13 +330,13 @@ def generate(self, size: int) -> NDArray[float64]:

Returns
-------
NDArray[np.float64]
NDArray[DType]
A NumPy array containing the generated samples. Returns an
empty array if :attr:`size` is 0.
"""

if size == 0:
return np.array([])
return np.array([], dtype=self.dtype)

component_choices = np.random.choice(self.n_components, size=size, p=self.weights)

Expand All @@ -324,7 +348,7 @@ def generate(self, size: int) -> NDArray[float64]:
np.random.shuffle(samples)
return samples

def __getitem__(self, key: int) -> "ContinuousDistribution":
def __getitem__(self, key: int) -> "ContinuousDistribution[DType]":
"""Retrieves components by index.

Parameters
Expand All @@ -334,46 +358,47 @@ def __getitem__(self, key: int) -> "ContinuousDistribution":

Returns
-------
ContinuousDistribution
ContinuousDistribution[DType]
A single component of the mixture
"""

return self.components[key]

def __iter__(self) -> Iterator["ContinuousDistribution"]:
def __iter__(self) -> Iterator["ContinuousDistribution[DType]"]:
"""Returns an iterator over the mixture components.

This allows the `MixtureModel` instance to be used directly in
loops, such as a `for` loop, to iterate over its components.

Yields
------
Iterator[ContinuousDistribution]
Iterator[ContinuousDistribution[DType]]
An iterator that yields the components of the mixture model.
"""

return iter(self.components)

def __copy__(self) -> "MixtureModel":
def __copy__(self) -> "MixtureModel[DType]":
"""Creates a copy of the mixture model instance.

Returns
-------
MixtureModel
MixtureModel[DType]
A new instance of the distribution, identical to the original.
"""

copied_components = [copy(component) for component in self._components]
new_mixture = MixtureModel(components=copied_components, weights=self.weights.copy())
new_mixture = MixtureModel(components=copied_components, weights=self.weights.copy(), dtype=self.dtype)
return new_mixture

def _get_sorted_pairs(self, for_hashing: bool = False) -> list[tuple["ContinuousDistribution", float]]:
def _get_sorted_pairs(self, for_hashing: bool = False) -> list[tuple["ContinuousDistribution[DType]", DType]]:
"""Internal helper to get component-weight pairs, sorted by component hash."""

if self._sorted_pairs_cache is None or for_hashing:
weights_to_use = self.weights
if for_hashing:
weights_to_use = np.round(weights_to_use, 8)
decimals = np.finfo(self.dtype).precision
weights_to_use = np.round(weights_to_use, decimals)

pairs = sorted(zip(self.components, weights_to_use), key=lambda p: hash(p[0]))
if not for_hashing:
Expand Down Expand Up @@ -401,7 +426,7 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, MixtureModel):
return NotImplemented

if self.n_components != other.n_components:
if self.dtype != other.dtype or self.n_components != other.n_components:
return False

self_pairs = self._get_sorted_pairs()
Expand All @@ -425,4 +450,4 @@ def __hash__(self) -> int:
"""

sorted_pairs_for_hash = self._get_sorted_pairs(for_hashing=True)
return hash(tuple(sorted_pairs_for_hash))
return hash((self.dtype, tuple(sorted_pairs_for_hash)))
Loading