From f27498259bc61b3fccbab1411b478bb95b7ce04a Mon Sep 17 00:00:00 2001
From: matdev83 <211248003+matdev83@users.noreply.github.com>
Date: Fri, 7 Nov 2025 15:11:49 +0100
Subject: [PATCH] Fix metrics timer stats deadlock

get_all_timer_stats() held the module-level _lock while calling
get_timer_stats(), which acquires the same lock. The lock is not
reentrant, so any call with at least one registered timer deadlocked.

Extract the lock-free helper _calculate_timer_stats() and have both
public functions snapshot the timer data while holding the lock, then
compute statistics from the copies after releasing it. The lock is now
acquired at most once per call, and callers still see a consistent
snapshot of each timer's durations.

Rewrite the test file to add a regression test that calls
get_all_timer_stats() from a worker thread with a join timeout, so a
reintroduced deadlock fails the test instead of hanging the suite.
---
 src/core/services/metrics_service.py          |  44 +--
 .../core/services/aaa_test_metrics_service.py | 260 ++++++++++--------
 2 files changed, 166 insertions(+), 138 deletions(-)

diff --git a/src/core/services/metrics_service.py b/src/core/services/metrics_service.py
index 2a81f23de..c51dccd2c 100644
--- a/src/core/services/metrics_service.py
+++ b/src/core/services/metrics_service.py
@@ -14,6 +14,28 @@
 _counters: dict[str, int] = defaultdict(int)
 _timers: dict[str, list[float]] = defaultdict(list)
 
 
+def _calculate_timer_stats(durations: list[float]) -> dict[str, Any]:
+    """Calculate statistics for the provided list of durations."""
+    if not durations:
+        return {
+            "count": 0,
+            "total": 0.0,
+            "average": 0.0,
+            "min": 0.0,
+            "max": 0.0,
+        }
+
+    total_duration = sum(durations)
+    count = len(durations)
+    return {
+        "count": count,
+        "total": total_duration,
+        "average": total_duration / count,
+        "min": min(durations),
+        "max": max(durations),
+    }
+
+
 def inc(name: str, by: int = 1) -> None:
     """Increment a counter metric by the specified amount.
@@ -90,23 +112,9 @@ def get_timer_stats(name: str) -> dict[str, Any]:
         A dictionary containing count, total, average, min, and max durations
     """
     with _lock:
-        durations = _timers.get(name, [])
-        if not durations:
-            return {
-                "count": 0,
-                "total": 0.0,
-                "average": 0.0,
-                "min": 0.0,
-                "max": 0.0,
-            }
+        durations = list(_timers.get(name, []))
 
-        return {
-            "count": len(durations),
-            "total": sum(durations),
-            "average": sum(durations) / len(durations),
-            "min": min(durations),
-            "max": max(durations),
-        }
+    return _calculate_timer_stats(durations)
 
 
 def get_all_timer_stats() -> dict[str, dict[str, Any]]:
@@ -116,7 +124,9 @@ def get_all_timer_stats() -> dict[str, dict[str, Any]]:
     Returns:
         A dictionary mapping timer names to their statistics
     """
     with _lock:
-        return {name: get_timer_stats(name) for name in _timers}
+        timers_snapshot = {name: list(durations) for name, durations in _timers.items()}
+
+    return {name: _calculate_timer_stats(durations) for name, durations in timers_snapshot.items()}
 
 def log_performance_stats() -> None:
diff --git a/tests/unit/core/services/aaa_test_metrics_service.py b/tests/unit/core/services/aaa_test_metrics_service.py
index 90e9c8e2c..546299757 100644
--- a/tests/unit/core/services/aaa_test_metrics_service.py
+++ b/tests/unit/core/services/aaa_test_metrics_service.py
@@ -1,121 +1,139 @@
-"""
-Unit tests for the metrics service. 
-""" - -from __future__ import annotations - -import time - -from src.core.services import metrics_service - - -class TestMetricsService: - """Test the metrics service functionality.""" - - def setup_method(self): - """Reset metrics before each test.""" - # Clear counters and timers - with metrics_service._lock: - metrics_service._counters.clear() - metrics_service._timers.clear() - - def test_counter_increment(self): - """Test basic counter increment functionality.""" - metrics_service.inc("test.counter") - assert metrics_service.get("test.counter") == 1 - - metrics_service.inc("test.counter", by=5) - assert metrics_service.get("test.counter") == 6 - - def test_counter_get_nonexistent(self): - """Test getting a counter that doesn't exist returns 0.""" - assert metrics_service.get("nonexistent.counter") == 0 - - def test_counter_snapshot(self): - """Test getting a snapshot of all counters.""" - metrics_service.inc("counter1") - metrics_service.inc("counter2", by=3) - metrics_service.inc("counter3", by=10) - - snapshot = metrics_service.snapshot() - assert snapshot["counter1"] == 1 - assert snapshot["counter2"] == 3 - assert snapshot["counter3"] == 10 - - def test_record_duration(self): - """Test recording duration measurements.""" - metrics_service.record_duration("test.timer", 0.5) - metrics_service.record_duration("test.timer", 1.0) - metrics_service.record_duration("test.timer", 0.75) - - stats = metrics_service.get_timer_stats("test.timer") - assert stats["count"] == 3 - assert stats["total"] == 2.25 - assert stats["average"] == 0.75 - assert stats["min"] == 0.5 - assert stats["max"] == 1.0 - - def test_timer_context_manager(self): - """Test the timer context manager.""" - with metrics_service.timer("test.operation"): - time.sleep(0.01) # Sleep for 10ms - - stats = metrics_service.get_timer_stats("test.operation") - assert stats["count"] == 1 - assert stats["total"] >= 0.01 # Should be at least 10ms - assert stats["average"] >= 0.01 - - def test_timer_stats_empty(self): - """Test getting stats for a timer with no measurements.""" - stats = metrics_service.get_timer_stats("nonexistent.timer") - assert stats["count"] == 0 - assert stats["total"] == 0.0 - assert stats["average"] == 0.0 - assert stats["min"] == 0.0 - assert stats["max"] == 0.0 - - def test_get_all_timer_stats(self): - """Test getting stats for all timers.""" - metrics_service.record_duration("timer1", 0.5) - metrics_service.record_duration("timer2", 1.0) - - all_stats = metrics_service.get_all_timer_stats() - assert "timer1" in all_stats - assert "timer2" in all_stats - assert all_stats["timer1"]["count"] == 1 - assert all_stats["timer2"]["count"] == 1 - - def test_tool_call_processing_metrics(self): - """Test metrics specific to tool call processing.""" - # Simulate processing and skipping messages - metrics_service.inc("tool_call.messages.processed", by=5) - metrics_service.inc("tool_call.messages.skipped", by=45) - - assert metrics_service.get("tool_call.messages.processed") == 5 - assert metrics_service.get("tool_call.messages.skipped") == 45 - - # Calculate skip rate - total = 5 + 45 - skip_rate = (45 / total) * 100 - assert skip_rate == 90.0 - - def test_log_performance_stats_with_data(self, caplog): - """Test logging performance statistics with data.""" - metrics_service.inc("tool_call.messages.processed", by=10) - metrics_service.inc("tool_call.messages.skipped", by=90) - metrics_service.record_duration("tool_call.processing.duration", 0.05) - metrics_service.record_duration("tool_call.processing.duration", 0.03) - 
-        metrics_service.log_performance_stats()
-
-        # Check that log messages were generated
-        assert any("processed=10" in record.message for record in caplog.records)
-        assert any("skipped=90" in record.message for record in caplog.records)
-        assert any("skip_rate=90.0%" in record.message for record in caplog.records)
-
-    def test_log_performance_stats_no_data(self, caplog):
-        """Test logging performance statistics with no data."""
-        metrics_service.log_performance_stats()
-
-        # Should not log anything when there's no data
-        assert len(caplog.records) == 0
+"""
+Unit tests for the metrics service.
+"""
+
+from __future__ import annotations
+
+import threading
+import time
+
+from src.core.services import metrics_service
+
+
+class TestMetricsService:
+    """Test the metrics service functionality."""
+
+    def setup_method(self):
+        """Reset metrics before each test."""
+        # Clear counters and timers
+        with metrics_service._lock:
+            metrics_service._counters.clear()
+            metrics_service._timers.clear()
+
+    def test_counter_increment(self):
+        """Test basic counter increment functionality."""
+        metrics_service.inc("test.counter")
+        assert metrics_service.get("test.counter") == 1
+
+        metrics_service.inc("test.counter", by=5)
+        assert metrics_service.get("test.counter") == 6
+
+    def test_counter_get_nonexistent(self):
+        """Test getting a counter that doesn't exist returns 0."""
+        assert metrics_service.get("nonexistent.counter") == 0
+
+    def test_counter_snapshot(self):
+        """Test getting a snapshot of all counters."""
+        metrics_service.inc("counter1")
+        metrics_service.inc("counter2", by=3)
+        metrics_service.inc("counter3", by=10)
+
+        snapshot = metrics_service.snapshot()
+        assert snapshot["counter1"] == 1
+        assert snapshot["counter2"] == 3
+        assert snapshot["counter3"] == 10
+
+    def test_record_duration(self):
+        """Test recording duration measurements."""
+        metrics_service.record_duration("test.timer", 0.5)
+        metrics_service.record_duration("test.timer", 1.0)
+        metrics_service.record_duration("test.timer", 0.75)
+
+        stats = metrics_service.get_timer_stats("test.timer")
+        assert stats["count"] == 3
+        assert stats["total"] == 2.25
+        assert stats["average"] == 0.75
+        assert stats["min"] == 0.5
+        assert stats["max"] == 1.0
+
+    def test_timer_context_manager(self):
+        """Test the timer context manager."""
+        with metrics_service.timer("test.operation"):
+            time.sleep(0.01)  # Sleep for 10ms
+
+        stats = metrics_service.get_timer_stats("test.operation")
+        assert stats["count"] == 1
+        assert stats["total"] >= 0.01  # Should be at least 10ms
+        assert stats["average"] >= 0.01
+
+    def test_timer_stats_empty(self):
+        """Test getting stats for a timer with no measurements."""
+        stats = metrics_service.get_timer_stats("nonexistent.timer")
+        assert stats["count"] == 0
+        assert stats["total"] == 0.0
+        assert stats["average"] == 0.0
+        assert stats["min"] == 0.0
+        assert stats["max"] == 0.0
+
+    def test_get_all_timer_stats(self):
+        """Test getting stats for all timers."""
+        metrics_service.record_duration("timer1", 0.5)
+        metrics_service.record_duration("timer2", 1.0)
+
+        all_stats = metrics_service.get_all_timer_stats()
+        assert "timer1" in all_stats
+        assert "timer2" in all_stats
+        assert all_stats["timer1"]["count"] == 1
+        assert all_stats["timer2"]["count"] == 1
+
+    def test_get_all_timer_stats_thread_safe(self):
+        """Ensure get_all_timer_stats does not deadlock when called from another thread."""
+        metrics_service.record_duration("timer1", 0.1)
+        metrics_service.record_duration("timer2", 0.2)
+
+        result: dict[str, dict[str, float]] = {}
+
+        def target() -> None:
+            result.update(metrics_service.get_all_timer_stats())
+
+        worker = threading.Thread(target=target)
+        worker.start()
+        worker.join(timeout=1)
+
+        assert not worker.is_alive(), "get_all_timer_stats deadlocked when called from another thread"
+        assert result, "Expected timer stats to be populated after thread execution"
+
+    def test_tool_call_processing_metrics(self):
+        """Test metrics specific to tool call processing."""
+        # Simulate processing and skipping messages
+        metrics_service.inc("tool_call.messages.processed", by=5)
+        metrics_service.inc("tool_call.messages.skipped", by=45)
+
+        assert metrics_service.get("tool_call.messages.processed") == 5
+        assert metrics_service.get("tool_call.messages.skipped") == 45
+
+        # Calculate skip rate
+        total = 5 + 45
+        skip_rate = (45 / total) * 100
+        assert skip_rate == 90.0
+
+    def test_log_performance_stats_with_data(self, caplog):
+        """Test logging performance statistics with data."""
+        metrics_service.inc("tool_call.messages.processed", by=10)
+        metrics_service.inc("tool_call.messages.skipped", by=90)
+        metrics_service.record_duration("tool_call.processing.duration", 0.05)
+        metrics_service.record_duration("tool_call.processing.duration", 0.03)
+
+        metrics_service.log_performance_stats()
+
+        # Check that log messages were generated
+        assert any("processed=10" in record.message for record in caplog.records)
+        assert any("skipped=90" in record.message for record in caplog.records)
+        assert any("skip_rate=90.0%" in record.message for record in caplog.records)
+
+    def test_log_performance_stats_no_data(self, caplog):
+        """Test logging performance statistics with no data."""
+        metrics_service.log_performance_stats()
+
+        # Should not log anything when there's no data
+        assert len(caplog.records) == 0
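-- 
Below the signature delimiter (stripped by git am), a minimal,
self-contained sketch of the failure mode this patch removes. It assumes
_lock is a plain non-reentrant threading.Lock, which is what the observed
deadlock implies; the stub functions and the "timer1" data are
illustrative stand-ins for the real module, not its actual code.

    import threading

    _lock = threading.Lock()
    _timers = {"timer1": [0.5]}

    def get_timer_stats(name):
        # Blocks forever if the caller already holds _lock: threading.Lock
        # cannot be acquired twice by the same thread.
        with _lock:
            return {"count": len(_timers.get(name, []))}

    def get_all_timer_stats():
        with _lock:  # first acquisition
            # Pre-patch shape: calls get_timer_stats() while _lock is
            # still held, so the nested acquire never returns.
            return {name: get_timer_stats(name) for name in _timers}

    # Detect the deadlock without hanging, mirroring the regression test:
    worker = threading.Thread(target=get_all_timer_stats, daemon=True)
    worker.start()
    worker.join(timeout=1)
    print("deadlocked:", worker.is_alive())  # True for the pre-patch shape

The patched code sidesteps the nested acquisition: each public function
copies the timer lists while holding the lock once, then hands the copies
to the pure _calculate_timer_stats() helper after the lock is released.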