Skip to content

Commit bacc65b

Browse files
authored
fix: #1907 guardrails w/ turn_detection.interrupt_response: true (#1968)
1 parent 1240562 commit bacc65b

File tree

4 files changed

+88
-27
lines changed

4 files changed

+88
-27
lines changed

src/agents/realtime/model_inputs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ class RealtimeModelSendToolOutput:
9595
class RealtimeModelSendInterrupt:
9696
"""Send an interrupt to the model."""
9797

98+
force_response_cancel: bool = False
99+
"""Force sending a response.cancel event even if automatic cancellation is enabled."""
100+
98101

99102
@dataclass
100103
class RealtimeModelSendSessionUpdate:

src/agents/realtime/openai_realtime.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -395,36 +395,36 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
395395
current_item_id = playback_state.get("current_item_id")
396396
current_item_content_index = playback_state.get("current_item_content_index")
397397
elapsed_ms = playback_state.get("elapsed_ms")
398+
398399
if current_item_id is None or elapsed_ms is None:
399400
logger.debug(
400401
"Skipping interrupt. "
401402
f"Item id: {current_item_id}, "
402403
f"elapsed ms: {elapsed_ms}, "
403404
f"content index: {current_item_content_index}"
404405
)
405-
return
406-
407-
current_item_content_index = current_item_content_index or 0
408-
if elapsed_ms > 0:
409-
await self._emit_event(
410-
RealtimeModelAudioInterruptedEvent(
411-
item_id=current_item_id,
412-
content_index=current_item_content_index,
413-
)
414-
)
415-
converted = _ConversionHelper.convert_interrupt(
416-
current_item_id,
417-
current_item_content_index,
418-
int(elapsed_ms),
419-
)
420-
await self._send_raw_message(converted)
421406
else:
422-
logger.debug(
423-
"Didn't interrupt bc elapsed ms is < 0. "
424-
f"Item id: {current_item_id}, "
425-
f"elapsed ms: {elapsed_ms}, "
426-
f"content index: {current_item_content_index}"
427-
)
407+
current_item_content_index = current_item_content_index or 0
408+
if elapsed_ms > 0:
409+
await self._emit_event(
410+
RealtimeModelAudioInterruptedEvent(
411+
item_id=current_item_id,
412+
content_index=current_item_content_index,
413+
)
414+
)
415+
converted = _ConversionHelper.convert_interrupt(
416+
current_item_id,
417+
current_item_content_index,
418+
int(elapsed_ms),
419+
)
420+
await self._send_raw_message(converted)
421+
else:
422+
logger.debug(
423+
"Didn't interrupt bc elapsed ms is < 0. "
424+
f"Item id: {current_item_id}, "
425+
f"elapsed ms: {elapsed_ms}, "
426+
f"content index: {current_item_content_index}"
427+
)
428428

429429
session = self._created_session
430430
automatic_response_cancellation_enabled = (
@@ -434,12 +434,16 @@ async def _send_interrupt(self, event: RealtimeModelSendInterrupt) -> None:
434434
and session.audio.input.turn_detection is not None
435435
and session.audio.input.turn_detection.interrupt_response is True
436436
)
437-
if not automatic_response_cancellation_enabled:
437+
should_cancel_response = event.force_response_cancel or (
438+
not automatic_response_cancellation_enabled
439+
)
440+
if should_cancel_response:
438441
await self._cancel_response()
439442

440-
self._audio_state_tracker.on_interrupted()
441-
if self._playback_tracker:
442-
self._playback_tracker.on_interrupted()
443+
if current_item_id is not None and elapsed_ms is not None:
444+
self._audio_state_tracker.on_interrupted()
445+
if self._playback_tracker:
446+
self._playback_tracker.on_interrupted()
443447

444448
async def _send_session_update(self, event: RealtimeModelSendSessionUpdate) -> None:
445449
"""Send a session update to the model."""

src/agents/realtime/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ async def _run_output_guardrails(self, text: str, response_id: str) -> bool:
704704
)
705705

706706
# Interrupt the model
707-
await self._model.send_event(RealtimeModelSendInterrupt())
707+
await self._model.send_event(RealtimeModelSendInterrupt(force_response_cancel=True))
708708

709709
# Send guardrail triggered message
710710
guardrail_names = [result.guardrail.get_name() for result in triggered_results]

tests/realtime/test_openai_realtime.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from types import SimpleNamespace
23
from typing import Any, cast
34
from unittest.mock import AsyncMock, Mock, patch
45

@@ -509,6 +510,59 @@ async def test_send_event_dispatch(self, model, monkeypatch):
509510
# session update -> 1
510511
assert send_raw.await_count == 8
511512

513+
@pytest.mark.asyncio
514+
async def test_interrupt_force_cancel_overrides_auto_cancellation(self, model, monkeypatch):
515+
"""Interrupt should send response.cancel even when auto cancel is enabled."""
516+
model._audio_state_tracker.set_audio_format("pcm16")
517+
model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
518+
model._ongoing_response = True
519+
model._created_session = SimpleNamespace(
520+
audio=SimpleNamespace(
521+
input=SimpleNamespace(
522+
turn_detection=SimpleNamespace(interrupt_response=True)
523+
)
524+
)
525+
)
526+
527+
send_raw = AsyncMock()
528+
emit_event = AsyncMock()
529+
monkeypatch.setattr(model, "_send_raw_message", send_raw)
530+
monkeypatch.setattr(model, "_emit_event", emit_event)
531+
532+
await model._send_interrupt(RealtimeModelSendInterrupt(force_response_cancel=True))
533+
534+
assert send_raw.await_count == 2
535+
payload_types = {call.args[0].type for call in send_raw.call_args_list}
536+
assert payload_types == {"conversation.item.truncate", "response.cancel"}
537+
assert model._ongoing_response is False
538+
assert model._audio_state_tracker.get_last_audio_item() is None
539+
540+
@pytest.mark.asyncio
541+
async def test_interrupt_respects_auto_cancellation_when_not_forced(self, model, monkeypatch):
542+
"""Interrupt should avoid sending response.cancel when relying on automatic cancellation."""
543+
model._audio_state_tracker.set_audio_format("pcm16")
544+
model._audio_state_tracker.on_audio_delta("item_1", 0, b"\x00" * 4800)
545+
model._ongoing_response = True
546+
model._created_session = SimpleNamespace(
547+
audio=SimpleNamespace(
548+
input=SimpleNamespace(
549+
turn_detection=SimpleNamespace(interrupt_response=True)
550+
)
551+
)
552+
)
553+
554+
send_raw = AsyncMock()
555+
emit_event = AsyncMock()
556+
monkeypatch.setattr(model, "_send_raw_message", send_raw)
557+
monkeypatch.setattr(model, "_emit_event", emit_event)
558+
559+
await model._send_interrupt(RealtimeModelSendInterrupt())
560+
561+
assert send_raw.await_count == 1
562+
assert send_raw.call_args_list[0].args[0].type == "conversation.item.truncate"
563+
assert all(call.args[0].type != "response.cancel" for call in send_raw.call_args_list)
564+
assert model._ongoing_response is True
565+
512566
def test_add_remove_listener_and_tools_conversion(self, model):
513567
listener = AsyncMock()
514568
model.add_listener(listener)

0 commit comments

Comments
 (0)