diff --git a/src/agents/voice/models/openai_stt.py b/src/agents/voice/models/openai_stt.py index 7ac008428..013aec671 100644 --- a/src/agents/voice/models/openai_stt.py +++ b/src/agents/voice/models/openai_stt.py @@ -321,18 +321,33 @@ def _check_errors(self) -> None: if exc and isinstance(exc, Exception): self._stored_exception = exc - def _cleanup_tasks(self) -> None: + async def _cleanup_tasks(self) -> None: + """Cancel all pending tasks and wait for them to complete. + + This ensures that any exceptions raised by the tasks are properly handled + and prevents warnings about unhandled task exceptions. + """ + tasks = [] + if self._listener_task and not self._listener_task.done(): self._listener_task.cancel() + tasks.append(self._listener_task) if self._process_events_task and not self._process_events_task.done(): self._process_events_task.cancel() + tasks.append(self._process_events_task) if self._stream_audio_task and not self._stream_audio_task.done(): self._stream_audio_task.cancel() + tasks.append(self._stream_audio_task) if self._connection_task and not self._connection_task.done(): self._connection_task.cancel() + tasks.append(self._connection_task) + + # Wait for all cancelled tasks to complete and collect exceptions + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) async def transcribe_turns(self) -> AsyncIterator[str]: self._connection_task = asyncio.create_task(self._process_websocket_connection()) @@ -367,7 +382,7 @@ async def close(self) -> None: if self._websocket: await self._websocket.close() - self._cleanup_tasks() + await self._cleanup_tasks() class OpenAISTTModel(STTModel): diff --git a/tests/voice/test_openai_stt.py b/tests/voice/test_openai_stt.py index 8eefc995f..6b6896947 100644 --- a/tests/voice/test_openai_stt.py +++ b/tests/voice/test_openai_stt.py @@ -378,3 +378,140 @@ async def test_inactivity_timeout(): assert len(collected_turns) == 0, "No transcripts expected, but we got something?" await session.close() + + +@pytest.mark.asyncio +async def test_cleanup_tasks_cancels_and_awaits_all_tasks(): + """ + Test that _cleanup_tasks() properly cancels and awaits all pending tasks. + This ensures proper resource cleanup and prevents unhandled task exceptions. + """ + mock_ws = create_mock_websocket( + [ + json.dumps({"type": "transcription_session.created"}), + json.dumps({"type": "transcription_session.updated"}), + ] + ) + + with patch("websockets.connect", return_value=mock_ws): + audio_input = await FakeStreamedAudioInput.get(count=2) + stt_settings = STTModelSettings() + + session = OpenAISTTTranscriptionSession( + input=audio_input, + client=AsyncMock(api_key="FAKE_KEY"), + model="whisper-1", + settings=stt_settings, + trace_include_sensitive_data=False, + trace_include_sensitive_audio_data=False, + ) + + # Create some tasks to simulate active background tasks + async def long_running_task(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + # Expected when cancelled + raise + + session._listener_task = asyncio.create_task(long_running_task()) + session._process_events_task = asyncio.create_task(long_running_task()) + session._stream_audio_task = asyncio.create_task(long_running_task()) + session._connection_task = asyncio.create_task(long_running_task()) + + # Verify tasks are running + assert not session._listener_task.done() + assert not session._process_events_task.done() + assert not session._stream_audio_task.done() + assert not session._connection_task.done() + + # Call cleanup_tasks + await session._cleanup_tasks() + + # Verify all tasks were cancelled and completed + assert session._listener_task.cancelled() + assert session._process_events_task.cancelled() + assert session._stream_audio_task.cancelled() + assert session._connection_task.cancelled() + + +@pytest.mark.asyncio +async def test_cleanup_tasks_handles_exceptions(): + """ + Test that _cleanup_tasks() properly handles exceptions from cancelled tasks + without raising them (using return_exceptions=True). + """ + mock_ws = create_mock_websocket( + [ + json.dumps({"type": "transcription_session.created"}), + json.dumps({"type": "transcription_session.updated"}), + ] + ) + + with patch("websockets.connect", return_value=mock_ws): + audio_input = await FakeStreamedAudioInput.get(count=2) + stt_settings = STTModelSettings() + + session = OpenAISTTTranscriptionSession( + input=audio_input, + client=AsyncMock(api_key="FAKE_KEY"), + model="whisper-1", + settings=stt_settings, + trace_include_sensitive_data=False, + trace_include_sensitive_audio_data=False, + ) + + # Create tasks that raise exceptions when cancelled + async def task_with_exception(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError as e: + raise RuntimeError("Task exception during cancellation") from e + + session._listener_task = asyncio.create_task(task_with_exception()) + session._process_events_task = asyncio.create_task(task_with_exception()) + + # cleanup_tasks should not raise despite the exceptions + await session._cleanup_tasks() + + # Tasks should be done (cancelled or exception raised) + assert session._listener_task.done() + assert session._process_events_task.done() + + +@pytest.mark.asyncio +async def test_close_calls_cleanup_tasks(): + """ + Test that close() properly calls _cleanup_tasks() to clean up background tasks. + """ + mock_ws = create_mock_websocket( + [ + json.dumps({"type": "transcription_session.created"}), + json.dumps({"type": "transcription_session.updated"}), + ] + ) + + with patch("websockets.connect", return_value=mock_ws): + audio_input = await FakeStreamedAudioInput.get(count=2) + stt_settings = STTModelSettings() + + session = OpenAISTTTranscriptionSession( + input=audio_input, + client=AsyncMock(api_key="FAKE_KEY"), + model="whisper-1", + settings=stt_settings, + trace_include_sensitive_data=False, + trace_include_sensitive_audio_data=False, + ) + + # Create a task + async def long_running_task(): + await asyncio.sleep(10) + + session._listener_task = asyncio.create_task(long_running_task()) + + # close() should cancel and await the task + await session.close() + + # Task should be cancelled + assert session._listener_task.cancelled()