From f684f37061187acabb6b2cdd2186a3a347cda0cb Mon Sep 17 00:00:00 2001 From: matdev83 <211248003+matdev83@users.noreply.github.com> Date: Fri, 7 Nov 2025 22:12:59 +0100 Subject: [PATCH] Prevent failover validator coroutine leaks --- src/core/persistence.py | 4 +- .../test_persistence_failover_validator.py | 58 +++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/test_persistence_failover_validator.py diff --git a/src/core/persistence.py b/src/core/persistence.py index 5f231d57c..6456dd76b 100644 --- a/src/core/persistence.py +++ b/src/core/persistence.py @@ -255,8 +255,6 @@ def validate(self, backend_name: str, model_name: str) -> FailoverValidationResu ), ) - coroutine = backend_service.validate_backend_and_model(backend_name, model_name) - try: asyncio.get_running_loop() loop_running = True @@ -274,6 +272,8 @@ def validate(self, backend_name: str, model_name: str) -> FailoverValidationResu ) return FailoverValidationResult(is_valid=True, warning=warning) + coroutine = backend_service.validate_backend_and_model(backend_name, model_name) + try: is_valid, message = asyncio.run(coroutine) except BackendError as exc: diff --git a/tests/unit/core/test_persistence_failover_validator.py b/tests/unit/core/test_persistence_failover_validator.py new file mode 100644 index 000000000..063f5db1d --- /dev/null +++ b/tests/unit/core/test_persistence_failover_validator.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import gc +import warnings +from typing import Any, TypeVar +from unittest.mock import AsyncMock + +import pytest +from src.core.persistence import ServiceProviderFailoverRouteValidator + +_T = TypeVar("_T") + + +class _DummyProvider: + def __init__(self, service: Any): + self._service = service + + def get_required_service(self, _service_type: type[_T]) -> _T: + return self._service + + +def _strict_supplier() -> bool: + return False + + +def test_validator_runs_backend_validation_when_loop_not_running() -> None: + backend_service = type("BackendService", (), {})() + backend_service.validate_backend_and_model = AsyncMock(return_value=(True, None)) + + validator = ServiceProviderFailoverRouteValidator( + _DummyProvider(backend_service), _strict_supplier + ) + + result = validator.validate("backend", "model") + + backend_service.validate_backend_and_model.assert_awaited_once() + assert result.is_valid is True + assert result.warning is None + + +@pytest.mark.asyncio +async def test_validator_does_not_leak_coroutines_when_loop_running() -> None: + backend_service = type("BackendService", (), {})() + backend_service.validate_backend_and_model = AsyncMock(return_value=(True, None)) + + validator = ServiceProviderFailoverRouteValidator( + _DummyProvider(backend_service), _strict_supplier + ) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", RuntimeWarning) + result = validator.validate("backend", "model") + gc.collect() + + assert result.is_valid is True + assert result.warning is not None + backend_service.validate_backend_and_model.assert_not_called() + assert not any("was never awaited" in str(w.message) for w in caught)