diff --git a/rework_pysatl_mpest/estimators/iterative/__init__.py b/rework_pysatl_mpest/estimators/iterative/__init__.py index 8a528f42..6d8075d3 100644 --- a/rework_pysatl_mpest/estimators/iterative/__init__.py +++ b/rework_pysatl_mpest/estimators/iterative/__init__.py @@ -18,7 +18,7 @@ from .breakpointer import Breakpointer -from .breakpointers import StepBreakpointer +from .breakpointers import LikelihoodBreakpointer, StepBreakpointer from .pipeline import Pipeline from .pipeline_state import PipelineState from .pipeline_step import PipelineStep @@ -29,6 +29,7 @@ __all__ = [ "Breakpointer", "ExpectationStep", + "LikelihoodBreakpointer", "MaximizationStep", "MaximizationStrategy", "OptimizationBlock", diff --git a/rework_pysatl_mpest/estimators/iterative/breakpointers/__init__.py b/rework_pysatl_mpest/estimators/iterative/breakpointers/__init__.py index 27e1cd7f..b270f842 100644 --- a/rework_pysatl_mpest/estimators/iterative/breakpointers/__init__.py +++ b/rework_pysatl_mpest/estimators/iterative/breakpointers/__init__.py @@ -6,6 +6,7 @@ __license__ = "SPDX-License-Identifier: MIT" +from .likelihood_breakpointer import LikelihoodBreakpointer from .step_breakpointer import StepBreakpointer -__all__ = ["StepBreakpointer"] +__all__ = ["LikelihoodBreakpointer", "StepBreakpointer"] diff --git a/rework_pysatl_mpest/estimators/iterative/breakpointers/likelihood_breakpointer.py b/rework_pysatl_mpest/estimators/iterative/breakpointers/likelihood_breakpointer.py new file mode 100644 index 00000000..a231c1d9 --- /dev/null +++ b/rework_pysatl_mpest/estimators/iterative/breakpointers/likelihood_breakpointer.py @@ -0,0 +1,91 @@ +"""Module that provides a :class:`rework_pysatl_mpest.estimators.iterative.Pipeline` +stopping strategy based on when log-likelihood of the mixture converges""" + +__author__ = "Maksim Pastukhov" +__copyright__ = "Copyright (c) 2025 PySATL project" +__license__ = "SPDX-License-Identifier: MIT" + +from typing import Optional + +from ..breakpointer 
class LikelihoodBreakpointer(Breakpointer):
    """Stop the pipeline once the mixture log-likelihood stops changing.

    On every call to :meth:`check` the log-likelihood of the current mixture
    is evaluated on the data held by the pipeline state and compared with the
    value remembered from the previous call. The pipeline is told to stop
    when

        |L_{t+1} - L_t| < threshold

    After convergence has been reported the remembered values are cleared,
    so the same instance can be reused for another run.

    Parameters
    ----------
    threshold : float
        The convergence threshold for the log-likelihood difference.
        Must be a positive number.

    Attributes
    ----------
    threshold : float
        The convergence threshold.

    Raises
    ------
    ValueError
        If `threshold` is not greater than 0 (this includes NaN).

    Methods
    -------
    check(state: PipelineState) -> bool
        Returns True if convergence is detected, False otherwise.
    """

    def __init__(self, threshold: float):
        self._validate(threshold)
        self.threshold = threshold
        # Log-likelihood seen on the previous / current call of `check`;
        # None means "nothing recorded yet".
        self._L_old: Optional[float] = None
        self._L_new: Optional[float] = None

    def _validate(self, threshold: float) -> None:
        """Validate the threshold parameter.

        Raises
        ------
        ValueError
            If `threshold` is not strictly greater than 0.
        """
        # `not threshold > 0` instead of `threshold <= 0`: every ordering
        # comparison with NaN is False, so the latter would let NaN through
        # and the convergence test in `check` could then never succeed.
        if not threshold > 0:
            raise ValueError("The threshold must be greater than 0")

    def check(self, state: PipelineState) -> bool:
        """Check whether the log-likelihood has converged.

        Computes ``state.curr_mixture.loglikelihood(state.X)`` and compares
        it with the value stored by the previous call.

        Parameters
        ----------
        state : PipelineState
            The current state of the pipeline, which must contain a valid
            `curr_mixture` and data `X`.

        Returns
        -------
        bool
            True if |L_new - L_old| < threshold (converged), False otherwise.
            On the first call (no previous likelihood), returns False and
            only records the current value.
        """
        self._L_new = state.curr_mixture.loglikelihood(state.X)

        # First iteration: nothing to compare against yet, just remember it.
        if self._L_old is None:
            self._L_old = self._L_new
            return False

        if abs(self._L_new - self._L_old) < self.threshold:
            # Converged: forget both values so the instance starts fresh
            # if it is used again.
            self._L_old = None
            self._L_new = None
            return True

        # Not converged: the current value becomes the reference for the
        # next call.
        self._L_old = self._L_new
        return False
class TestInitialization:
    """Constructor behaviour of ``LikelihoodBreakpointer``."""

    @pytest.mark.parametrize("threshold", [0.01, 0.5, 10.0])
    def test_initialization_with_valid_threshold(self, threshold: float):
        """A positive threshold is stored and no likelihood is remembered."""
        breakpointer = LikelihoodBreakpointer(threshold)

        assert breakpointer.threshold == threshold
        assert breakpointer._L_old is None
        assert breakpointer._L_new is None

    def test_initialization_rejects_non_positive_threshold(self):
        """Zero and negative thresholds are refused with ``ValueError``."""
        # Same order as before: zero first, then a negative value.
        for bad_threshold in (0.0, -1.0):
            with pytest.raises(ValueError, match="The threshold must be greater than 0"):
                LikelihoodBreakpointer(bad_threshold)


# --- Core Logic Tests ---


class TestCheckLogic:
    """Behaviour of ``LikelihoodBreakpointer.check`` over successive calls."""

    def test_first_call_never_stops(self, mock_mixture_with_likelihood, dummy_state_factory):
        """The first call only records the likelihood and reports no stop."""
        initial_ll = 5.0
        state = dummy_state_factory(mock_mixture_with_likelihood([5.0]))
        breakpointer = LikelihoodBreakpointer(0.1)

        assert not breakpointer.check(state)
        assert breakpointer._L_old == initial_ll

    def test_convergence_detected(self, mock_mixture_with_likelihood, dummy_state_factory):
        """A change smaller than the threshold stops on the second call."""
        state = dummy_state_factory(mock_mixture_with_likelihood([10.0, 10.05]))
        breakpointer = LikelihoodBreakpointer(0.1)

        assert not breakpointer.check(state)
        assert breakpointer.check(state)

    def test_no_convergence_continues(self, mock_mixture_with_likelihood, dummy_state_factory):
        """A change larger than the threshold keeps the pipeline running."""
        state = dummy_state_factory(mock_mixture_with_likelihood([5.0, 6.0]))
        breakpointer = LikelihoodBreakpointer(0.5)

        assert not breakpointer.check(state)
        assert not breakpointer.check(state)

    def test_reset_after_convergence_enables_reuse(self, mock_mixture_with_likelihood, dummy_state_factory):
        """After reporting convergence the instance behaves like a new one."""
        likelihoods = [10.0, 10.01, 20.0, 20.005]
        state = dummy_state_factory(mock_mixture_with_likelihood(likelihoods))
        breakpointer = LikelihoodBreakpointer(0.02)

        # Two full cycles: each one records a value, then converges.
        for _cycle in range(2):
            assert not breakpointer.check(state)
            assert breakpointer.check(state)

        # Internal state is cleared after the final convergence.
        assert breakpointer._L_old is None
        assert breakpointer._L_new is None