From 553c1bf098034a1150ee3e7316ab39b910368456 Mon Sep 17 00:00:00 2001 From: ruskaruma Date: Sat, 8 Nov 2025 02:57:56 +0530 Subject: [PATCH] Fix realtime PCM duration calculation (#2010) --- src/agents/realtime/_util.py | 17 ++++++++++++++--- tests/realtime/test_openai_realtime.py | 19 ++++++++++++------- tests/realtime/test_playback_tracker.py | 11 ++++++----- .../test_playback_tracker_manual_unit.py | 6 +++--- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/src/agents/realtime/_util.py b/src/agents/realtime/_util.py index 52a3483e9..4de38f06f 100644 --- a/src/agents/realtime/_util.py +++ b/src/agents/realtime/_util.py @@ -2,8 +2,19 @@ from .config import RealtimeAudioFormat +PCM16_SAMPLE_RATE_HZ = 24_000 +PCM16_SAMPLE_WIDTH_BYTES = 2 +G711_SAMPLE_RATE_HZ = 8_000 + def calculate_audio_length_ms(format: RealtimeAudioFormat | None, audio_bytes: bytes) -> float: - if format and isinstance(format, str) and format.startswith("g711"): - return (len(audio_bytes) / 8000) * 1000 - return (len(audio_bytes) / 24 / 2) * 1000 + if not audio_bytes: + return 0.0 + + normalized_format = format.lower() if isinstance(format, str) else None + + if normalized_format and normalized_format.startswith("g711"): + return (len(audio_bytes) / G711_SAMPLE_RATE_HZ) * 1000 + + samples = len(audio_bytes) / PCM16_SAMPLE_WIDTH_BYTES + return (samples / PCM16_SAMPLE_RATE_HZ) * 1000 diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py index 08c45e5d7..721ad1ec8 100644 --- a/tests/realtime/test_openai_realtime.py +++ b/tests/realtime/test_openai_realtime.py @@ -698,23 +698,27 @@ async def test_audio_timing_calculation_accuracy(self, model): for event in audio_deltas: await model._handle_ws_event(event) - # Should accumulate audio length: 8 bytes / 24 / 2 * 1000 = milliseconds - # Total: 8 bytes / 24 / 2 * 1000 - expected_length = (8 / 24 / 2) * 1000 + # Should accumulate audio length: 8 bytes -> 4 samples -> (4 / 24000) * 1000 ≈ 0.167 ms + expected_length = (8 / (24_000 * 2)) * 1000 # Test through the actual audio state tracker audio_state = model._audio_state_tracker.get_state("item_1", 0) assert audio_state is not None - assert abs(audio_state.audio_length_ms - expected_length) < 0.001 + assert audio_state.audio_length_ms == pytest.approx(expected_length, rel=0, abs=1e-6) def test_calculate_audio_length_ms_pure_function(self, model): """Test the pure audio length calculation function.""" from agents.realtime._util import calculate_audio_length_ms # Test various audio buffer sizes for pcm16 format - assert calculate_audio_length_ms("pcm16", b"test") == (4 / 24 / 2) * 1000 # 4 bytes + expected_pcm = (len(b"test") / (24_000 * 2)) * 1000 + assert calculate_audio_length_ms("pcm16", b"test") == pytest.approx( + expected_pcm, rel=0, abs=1e-6 + ) # 4 bytes assert calculate_audio_length_ms("pcm16", b"") == 0 # empty - assert calculate_audio_length_ms("pcm16", b"a" * 48) == 1000.0 # exactly 1000ms worth + assert calculate_audio_length_ms("pcm16", b"a" * 48) == pytest.approx( + (48 / (24_000 * 2)) * 1000, rel=0, abs=1e-6 + ) # exactly 1ms worth # Test g711 format assert calculate_audio_length_ms("g711_ulaw", b"test") == (4 / 8000) * 1000 # 4 bytes @@ -741,7 +745,8 @@ async def test_handle_audio_delta_state_management(self, model): # Test that audio state is tracked correctly audio_state = model._audio_state_tracker.get_state("test_item", 5) assert audio_state is not None - assert audio_state.audio_length_ms == (4 / 24 / 2) * 1000 # 4 bytes in milliseconds + expected_ms = (len(b"test") / (24_000 * 2)) * 1000 + assert audio_state.audio_length_ms == pytest.approx(expected_ms, rel=0, abs=1e-6) # Test that last audio item is tracked last_item = model._audio_state_tracker.get_last_audio_item() diff --git a/tests/realtime/test_playback_tracker.py b/tests/realtime/test_playback_tracker.py index c0bfba468..135034ec2 100644 --- a/tests/realtime/test_playback_tracker.py +++ b/tests/realtime/test_playback_tracker.py @@ -64,9 +64,9 @@ def test_audio_state_accumulation_across_deltas(self): state = tracker.get_state("item_1", 0) assert state is not None - # Should accumulate: 8 bytes / 24 / 2 * 1000 = 166.67ms - expected_length = (8 / 24 / 2) * 1000 - assert abs(state.audio_length_ms - expected_length) < 0.01 + # Should accumulate: 8 bytes -> 4 samples -> (4 / 24000) * 1000 ≈ 0.167ms + expected_length = (8 / (24_000 * 2)) * 1000 + assert state.audio_length_ms == pytest.approx(expected_length, rel=0, abs=1e-6) def test_state_cleanup_on_interruption(self): """Test both trackers properly reset state on interruption.""" @@ -105,8 +105,9 @@ def test_audio_length_calculation_with_different_formats(self): # Test PCM format (24kHz, default) pcm_bytes = b"test" # 4 bytes pcm_length = calculate_audio_length_ms("pcm16", pcm_bytes) - assert pcm_length == (4 / 24 / 2) * 1000 # ~83.33ms + expected_pcm = (len(pcm_bytes) / (24_000 * 2)) * 1000 + assert pcm_length == pytest.approx(expected_pcm, rel=0, abs=1e-6) # Test None format (defaults to PCM) none_length = calculate_audio_length_ms(None, pcm_bytes) - assert none_length == pcm_length + assert none_length == pytest.approx(expected_pcm, rel=0, abs=1e-6) diff --git a/tests/realtime/test_playback_tracker_manual_unit.py b/tests/realtime/test_playback_tracker_manual_unit.py index 35adc1264..ff901dd84 100644 --- a/tests/realtime/test_playback_tracker_manual_unit.py +++ b/tests/realtime/test_playback_tracker_manual_unit.py @@ -5,16 +5,16 @@ def test_playback_tracker_on_play_bytes_and_state(): tr = RealtimePlaybackTracker() tr.set_audio_format("pcm16") # PCM path - # 48k bytes -> (48000 / 24 / 2) * 1000 = 1,000,000ms per current util + # 48k bytes -> (48000 / (24000 * 2)) * 1000 = 1_000ms tr.on_play_bytes("item1", 0, b"x" * 48000) st = tr.get_state() assert st["current_item_id"] == "item1" - assert st["elapsed_ms"] and abs(st["elapsed_ms"] - 1_000_000.0) < 1e-6 + assert st["elapsed_ms"] and abs(st["elapsed_ms"] - 1_000.0) < 1e-6 # Subsequent play on same item accumulates tr.on_play_ms("item1", 0, 500.0) st2 = tr.get_state() - assert st2["elapsed_ms"] and abs(st2["elapsed_ms"] - 1_000_500.0) < 1e-6 + assert st2["elapsed_ms"] and abs(st2["elapsed_ms"] - 1_500.0) < 1e-6 # Interruption clears state tr.on_interrupted()