Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rework_pysatl_mpest/estimators/iterative/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


from .breakpointer import Breakpointer
from .breakpointers import StepBreakpointer
from .breakpointers import LikelihoodBreakpointer, StepBreakpointer
from .pipeline import Pipeline
from .pipeline_state import PipelineState
from .pipeline_step import PipelineStep
Expand All @@ -29,6 +29,7 @@
__all__ = [
"Breakpointer",
"ExpectationStep",
"LikelihoodBreakpointer",
"MaximizationStep",
"MaximizationStrategy",
"OptimizationBlock",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
__license__ = "SPDX-License-Identifier: MIT"


from .likelihood_breakpointer import LikelihoodBreakpointer
from .step_breakpointer import StepBreakpointer

__all__ = ["StepBreakpointer"]
__all__ = ["LikelihoodBreakpointer", "StepBreakpointer"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Module that provides a :class:`rework_pysatl_mpest.estimators.iterative.Pipeline`
stopping strategy based on when log-likelihood of the mixture converges"""

__author__ = "Maksim Pastukhov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from typing import Optional

from ..breakpointer import Breakpointer
from ..pipeline_state import PipelineState


class LikelihoodBreakpointer(Breakpointer):
"""Stops the pipeline when the log-likelihood of the mixture converges.

This breakpointer terminates the iterative estimation process when the
absolute difference between the current and previous log-likelihood values
falls below a specified threshold:

|L_{t+1} - L_t| < threshold

It tracks the log-likelihood of the current mixture model on the observed
data at each iteration and compares it to the previous value.

Parameters
----------
threshold : float
The convergence threshold for the log-likelihood difference.
Must be a positive number.

Attributes
----------
threshold : float
The convergence threshold.

Raises
------
ValueError
If `threshold` is not greater than 0.

Methods
-------
check(state: PipelineState) -> bool
Returns True if convergence is detected, False otherwise.
"""

def __init__(self, threshold: float):
self._validate(threshold)
self.threshold = threshold
self._L_old: Optional[float] = None
self._L_new: Optional[float] = None

def _validate(self, threshold: float):
"""Validates the threshold parameter."""
if threshold <= 0:
raise ValueError("The threshold must be greater than 0")

def check(self, state: PipelineState) -> bool:
"""Checks if the log-likelihood has converged.

Computes the current log-likelihood of the mixture on the data in
the pipeline state and compares it with the previous value.

Parameters
----------
state : PipelineState
The current state of the pipeline, which must contain a valid
`curr_mixture` and data `X`.

Returns
-------
bool
True if |L_new - L_old| < threshold (converged), False otherwise.
On the first call (no previous likelihood), returns False and
initializes internal state.
"""
self._L_new = state.curr_mixture.loglikelihood(state.X)

# First iteration: cannot compare, so just store and continue
if self._L_old is None:
self._L_old = self._L_new
return False

if abs(self._L_new - self._L_old) < self.threshold:
self._L_old = None
self._L_new = None
return True
else:
self._L_old = self._L_new
return False
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Tests for LikelihoodBreakpointer"""

__author__ = "Maksim Pastukhov"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from unittest.mock import Mock

import numpy as np
import pytest
from rework_pysatl_mpest.estimators.iterative import PipelineState
from rework_pysatl_mpest.estimators.iterative.breakpointers.likelihood_breakpointer import LikelihoodBreakpointer


@pytest.fixture
def mock_mixture_with_likelihood():
"""Returns a mock mixture that returns predefined log-likelihoods."""

def make_mixture(ll_values):
mixture = Mock()
gen = iter(ll_values)
mixture.loglikelihood = lambda X: next(gen)
return mixture

return make_mixture


@pytest.fixture
def dummy_state_factory():
"""Factory to create PipelineState with custom mixture."""

def _make_state(mixture, X=None):
if X is None:
X = np.array([1.0, 2.0, 3.0])
return PipelineState(X, None, None, mixture, None)

return _make_state


# --- Initialization Tests ---


class TestInitialization:
@pytest.mark.parametrize("threshold", [0.01, 0.5, 10.0])
def test_initialization_with_valid_threshold(self, threshold: float):
bp = LikelihoodBreakpointer(threshold)
assert bp.threshold == threshold
assert bp._L_old is None
assert bp._L_new is None

def test_initialization_rejects_non_positive_threshold(self):
with pytest.raises(ValueError, match="The threshold must be greater than 0"):
LikelihoodBreakpointer(0.0)
with pytest.raises(ValueError, match="The threshold must be greater than 0"):
LikelihoodBreakpointer(-1.0)


# --- Core Logic Tests ---


class TestCheckLogic:
def test_first_call_never_stops(self, mock_mixture_with_likelihood, dummy_state_factory):
FIRST_CALL = 5.0
mixture = mock_mixture_with_likelihood([5.0])
state = dummy_state_factory(mixture)
bp = LikelihoodBreakpointer(0.1)
assert not bp.check(state)
assert bp._L_old == FIRST_CALL

def test_convergence_detected(self, mock_mixture_with_likelihood, dummy_state_factory):
mixture = mock_mixture_with_likelihood([10.0, 10.05])
state = dummy_state_factory(mixture)
bp = LikelihoodBreakpointer(0.1)

assert not bp.check(state)
assert bp.check(state)

def test_no_convergence_continues(self, mock_mixture_with_likelihood, dummy_state_factory):
mixture = mock_mixture_with_likelihood([5.0, 6.0])
state = dummy_state_factory(mixture)
bp = LikelihoodBreakpointer(0.5)

assert not bp.check(state)
assert not bp.check(state)

def test_reset_after_convergence_enables_reuse(self, mock_mixture_with_likelihood, dummy_state_factory):
mixture = mock_mixture_with_likelihood([10.0, 10.01, 20.0, 20.005])
state = dummy_state_factory(mixture)
bp = LikelihoodBreakpointer(0.02)

# First cycle
assert not bp.check(state)
assert bp.check(state)

# After reset, should behave like new
assert not bp.check(state)
assert bp.check(state)

# Internal state reset
assert bp._L_old is None
assert bp._L_new is None