diff --git a/src/core/di/weak_container.py b/src/core/di/weak_container.py index 7d8efae78..a5de59885 100644 --- a/src/core/di/weak_container.py +++ b/src/core/di/weak_container.py @@ -71,12 +71,19 @@ def register_instance( if cleanup_callback: self._cleanup_callbacks[service_type] = cleanup_callback - # Set up weak reference callback for cleanup - def on_delete(ref): + instance_ref = weakref.ref(instance) + + # Set up weak reference callback for cleanup without strong reference cycles + def on_delete(_: weakref.ReferenceType[Any]) -> None: + service_instance = instance_ref() + if service_instance is None: + return try: - cleanup_callback(instance) - except Exception as e: - logger.warning(f"Error in cleanup callback for {service_type}: {e}") + cleanup_callback(service_instance) + except Exception as exc: + logger.warning( + f"Error in cleanup callback for {service_type}: {exc}", + ) weakref.ref(instance, on_delete) @@ -123,13 +130,17 @@ async def get_service(self, service_type: type[T]) -> T: # Set up cleanup callback if registered cleanup_callback = self._cleanup_callbacks.get(service_type) if cleanup_callback: + instance_ref = weakref.ref(instance) - def on_delete(ref): + def on_delete(_: weakref.ReferenceType[Any]) -> None: + service_instance = instance_ref() + if service_instance is None: + return try: - cleanup_callback(instance) - except Exception as e: + cleanup_callback(service_instance) + except Exception as exc: logger.warning( - f"Error in cleanup callback for {service_type}: {e}" + f"Error in cleanup callback for {service_type}: {exc}", ) weakref.ref(instance, on_delete) diff --git a/tests/unit/core/di/test_weak_container.py b/tests/unit/core/di/test_weak_container.py new file mode 100644 index 000000000..eb396b61a --- /dev/null +++ b/tests/unit/core/di/test_weak_container.py @@ -0,0 +1,36 @@ +"""Tests for the weak DI container memory management.""" + +from __future__ import annotations + +import gc +import weakref + +import pytest +from src.core.di.weak_container import WeakDIContainer + + +class _DummyService: + """Simple service used to verify garbage collection behavior.""" + + +@pytest.mark.asyncio +async def test_weak_container_allows_garbage_collection() -> None: + """Instances registered in the weak container should not be leaked.""" + + container = WeakDIContainer() + + def factory() -> _DummyService: + return _DummyService() + + container.register_factory(_DummyService, factory) + + service = await container.get_service(_DummyService) + service_ref = weakref.ref(service) + + # Drop the strong reference held by the test + del service + + # Force garbage collection to trigger weakref callbacks + gc.collect() + + assert service_ref() is None