From 107f03533c2abde1586109ec4d4709a556801c21 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 25 Sep 2025 10:47:46 -0400 Subject: [PATCH 01/23] feat(bidirectional_streaming): Add experimental bidirectional streaming MVP POC implementation --- pyproject.toml | 9 +- .../bidirectional_streaming/agent/__init__.py | 2 + .../bidirectional_streaming/agent/agent.py | 167 ++++ .../event_loop/__init__.py | 2 + .../event_loop/bidirectional_event_loop.py | 539 ++++++++++++ .../models/__init__.py | 2 + .../models/bidirectional_model.py | 115 +++ .../models/novasonic.py | 777 ++++++++++++++++++ .../tests/test_bidirectional_streaming.py | 203 +++++ .../bidirectional_streaming/types/__init__.py | 3 + .../types/bidirectional_streaming.py | 167 ++++ .../bidirectional_streaming/utils/debug.py | 45 + 12 files changed, 2030 insertions(+), 1 deletion(-) create mode 100644 src/strands/experimental/bidirectional_streaming/agent/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/agent/agent.py create mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py create mode 100644 src/strands/experimental/bidirectional_streaming/models/novasonic.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py create mode 100644 src/strands/experimental/bidirectional_streaming/types/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py create mode 100644 src/strands/experimental/bidirectional_streaming/utils/debug.py diff --git a/pyproject.toml b/pyproject.toml index 3c2243299..d4f7e6eee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,13 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +bidirectional-streaming = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", +] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<6.0.0", @@ -68,7 +75,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py new file mode 100644 index 000000000..bbd2c91f3 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming agent package.""" +# Agent package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py new file mode 100644 index 000000000..cfc005576 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -0,0 +1,167 @@ +"""Bidirectional Agent for real-time streaming conversations. + +AGENT PURPOSE: +------------- +Provides type-safe constructor and session management for real-time audio/text +interaction. Serves as the bidirectional equivalent to invoke_async() → stream_async() +but establishes sessions that continue indefinitely with concurrent task management. + +ARCHITECTURAL APPROACH: +---------------------- +While invoke_async() creates single request-response cycles that terminate after +stop_reason: "end_turn" with sequential tool processing, start_conversation() +establishes persistent sessions with concurrent processing of model events, tool +execution, and user input without session termination. + +DESIGN CHOICE: +------------- +Uses dedicated BidirectionalAgent class (Option 1 from design document) for: +- Type safety with no conditional behavior based on model type +- Separation of concerns - solely focused on bidirectional streaming +- Future proofing - allows changes without implications to existing Agent class +""" + +import asyncio +import logging +from typing import AsyncIterable, List, Optional + +from strands.tools.registry import ToolRegistry +from strands.types.content import Messages + +from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection +from ..models.bidirectional_model import BidirectionalModel +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..utils.debug import log_event, log_flow + +logger = logging.getLogger(__name__) + + +class BidirectionalAgent: + """Agent for bidirectional streaming conversations. + + Provides type-safe constructor and session management for real-time + audio/text interaction with concurrent processing capabilities. + """ + + def __init__( + self, + model: BidirectionalModel, + tools: Optional[List] = None, + system_prompt: Optional[str] = None, + messages: Optional[Messages] = None + ): + """Initialize bidirectional agent with required model and optional configuration. + + Args: + model: BidirectionalModel instance supporting streaming sessions. + tools: Optional list of tools available to the model. + system_prompt: Optional system prompt for conversations. + messages: Optional conversation history to initialize with. + """ + self.model = model + self.system_prompt = system_prompt + self.messages = messages or [] + + # Initialize tool registry using existing Strands infrastructure + self.tool_registry = ToolRegistry() + if tools: + self.tool_registry.process_tools(tools) + self.tool_registry.initialize_tools() + + # Initialize tool executor for concurrent execution + from strands.tools.executors import ConcurrentToolExecutor + self.tool_executor = ConcurrentToolExecutor() + + # Session management + self._session = None + self._output_queue = asyncio.Queue() + + async def start_conversation(self) -> None: + """Initialize persistent bidirectional session for real-time interaction. + + Creates provider-specific session and starts concurrent background tasks + for model events, tool execution, and session lifecycle management. + + Raises: + ValueError: If conversation already active. + ConnectionError: If session creation fails. + """ + if self._session and self._session.active: + raise ValueError("Conversation already active. Call end_conversation() first.") + + log_flow("conversation_start", "initializing session") + self._session = await start_bidirectional_connection(self) + log_event("conversation_ready") + + async def send_text(self, text: str) -> None: + """Send text input during active session without interrupting model generation. + + Args: + text: Text message to send to the model. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + log_event("text_sent", length=len(text)) + await self._session.model_session.send_text_content(text) + + async def send_audio(self, audio_input: AudioInputEvent) -> None: + """Send audio input during active session for real-time speech interaction. + + Args: + audio_input: AudioInputEvent containing audio data and configuration. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_audio_content(audio_input) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive output events from the model including audio, text. + + Provides access to model output events processed by background tasks. + Events include audio output, text responses, tool calls, and session updates. + + Yields: + BidirectionalStreamEvent: Events from the model session. + """ + while self._session and self._session.active: + try: + event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) + yield event + except asyncio.TimeoutError: + continue + + async def interrupt(self) -> None: + """Interrupt current model generation and switch to listening mode. + + Sends interruption signal to immediately stop generation and clear + pending audio output for responsive conversational experience. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_interrupt() + + async def end_conversation(self) -> None: + """End session and cleanup resources including background tasks. + + Performs graceful session termination with proper resource cleanup + including background task cancellation and connection closure. + """ + if self._session: + await stop_bidirectional_connection(self._session) + self._session = None + + def _validate_active_session(self) -> None: + """Validate that an active session exists. + + Raises: + ValueError: If no active session. + """ + if not self._session or not self._session.active: + raise ValueError("No active conversation. Call start_conversation() first.") + diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py new file mode 100644 index 000000000..24080b703 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming event loop package.""" +# Event Loop package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py new file mode 100644 index 000000000..2164115d8 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -0,0 +1,539 @@ +"""Bidirectional session management for concurrent streaming conversations. + +SESSION PURPOSE: +--------------- +Session wrapper for bidirectional communication that manages concurrent tasks for +model events, tool execution, and audio processing while providing simple interface +for Agent interaction. + +CONCURRENT ARCHITECTURE: +----------------------- +Unlike existing event_loop_cycle() that processes events sequentially where tool +execution blocks conversation, this module coordinates concurrent tasks through +asyncio queues and background task management. +""" + +import asyncio +import json +import logging +import traceback +import uuid +from typing import Any, Dict + +from strands.tools._validator import validate_and_prepare_tools +from strands.types.content import Message +from strands.types.tools import ToolResult, ToolUse + +from ..models.bidirectional_model import BidirectionalModelSession +from ..utils.debug import log_event, log_flow + +logger = logging.getLogger(__name__) + +# Session constants +TOOL_QUEUE_TIMEOUT = 0.5 +SUPERVISION_INTERVAL = 0.1 + + +class BidirectionalConnection: + """Session wrapper for bidirectional communication. + + Manages concurrent tasks for model events, tool execution, and audio processing + while providing simple interface for Agent interaction. + """ + + def __init__(self, model_session: BidirectionalModelSession, agent): + """Initialize session with model session and agent reference. + + Args: + model_session: Provider-specific bidirectional model session. + agent: BidirectionalAgent instance for tool registry access. + """ + self.model_session = model_session + self.agent = agent + self.active = True + + # Background processing coordination + self.background_tasks = [] + self.tool_queue = asyncio.Queue() + self.audio_output_queue = asyncio.Queue() + + # Task management for cleanup + self.pending_tool_tasks: Dict[str, asyncio.Task] = {} + + # Interruption handling (model-agnostic) + self.interrupted = False + +async def start_bidirectional_connection(agent) -> BidirectionalConnection: + """Initialize bidirectional session with concurrent background tasks. + + Creates provider-specific session and starts concurrent tasks for model events, + tool execution, and session lifecycle management. + + Args: + agent: BidirectionalAgent instance. + + Returns: + BidirectionalConnection: Active session with background tasks running. + """ + log_flow("session_start", "initializing model session") + + # Create provider-specific session + model_session = await agent.model.create_bidirectional_connection( + system_prompt=agent.system_prompt, + tools=agent.tool_registry.get_all_tool_specs(), + messages=agent.messages + ) + + # Create session wrapper for background processing + session = BidirectionalConnection(model_session=model_session, agent=agent) + + # Start concurrent background processors IMMEDIATELY after session creation + # This is critical - Nova Sonic needs response processing during initialization + log_flow("background_tasks", "starting processors") + session.background_tasks = [ + asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_tool_execution(session)) # Execute tools concurrently + ] + + # Start main coordination cycle + session.main_cycle_task = asyncio.create_task( + bidirectional_event_loop_cycle(session) + ) + + # Give background tasks a moment to start + await asyncio.sleep(0.1) + log_event("session_ready", tasks=len(session.background_tasks)) + + return session + + +async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: + """End session and cleanup resources including background tasks. + + Args: + session: BidirectionalConnection to cleanup. + """ + if not session.active: + return + + log_flow("session_cleanup", "starting") + session.active = False + + # Cancel pending tool tasks + for _, task in session.pending_tool_tasks.items(): + if not task.done(): + task.cancel() + + # Cancel background tasks + for task in session.background_tasks: + if not task.done(): + task.cancel() + + # Cancel main cycle task + if hasattr(session, 'main_cycle_task') and not session.main_cycle_task.done(): + session.main_cycle_task.cancel() + + # Wait for tasks to complete + all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) + if hasattr(session, 'main_cycle_task'): + all_tasks.append(session.main_cycle_task) + + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) + + # Close model session + await session.model_session.close() + log_event("session_closed") + + +async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: + """Main bidirectional event loop coordinator - runs continuously during session. + + Coordinates background tasks and manages session lifecycle. Unlike the + sequential event_loop_cycle() that processes events one by one, this coordinator + manages concurrent tasks and session state. + + Args: + session: BidirectionalConnection to coordinate. + """ + while session.active: + try: + # Check if background processors are still running + if all(task.done() for task in session.background_tasks): + log_event("session_end", reason="all_processors_completed") + session.active = False + break + + # Check for failed background tasks + for i, task in enumerate(session.background_tasks): + if task.done() and not task.cancelled(): + exception = task.exception() + if exception: + log_event("session_error", processor=i, error=str(exception)) + session.active = False + raise exception + + # Brief pause before next supervision check + await asyncio.sleep(SUPERVISION_INTERVAL) + + except asyncio.CancelledError: + break + except Exception as e: + log_event("event_loop_error", error=str(e)) + session.active = False + raise + + +async def _handle_interruption(session: BidirectionalConnection) -> None: + """Handle interruption detection with comprehensive task cancellation. + + Sets interruption flag, cancels pending tool tasks, and aggressively + clears audio output queue following Nova Sonic example patterns. + + Args: + session: BidirectionalConnection to handle interruption for. + """ + log_event("interruption_detected") + session.interrupted = True + + # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) + cancelled_tools = 0 + for task_id, task in list(session.pending_tool_tasks.items()): + if not task.done(): + task.cancel() + cancelled_tools += 1 + log_event("tool_task_cancelled", task_id=task_id) + + if cancelled_tools > 0: + log_event("tool_tasks_cancelled", count=cancelled_tools) + + # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) + cleared_count = 0 + while True: + try: + session.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break + + # Also clear the agent's audio output queue if it exists + if hasattr(session.agent, '_output_queue'): + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) + + if cleared_count > 0: + log_event("session_audio_queue_cleared", count=cleared_count) + + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) + await asyncio.sleep(0.05) + + # Reset interruption flag after clearing (automatic recovery) + session.interrupted = False + log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + + +async def _process_model_events(session: BidirectionalConnection) -> None: + """Process model events using existing Strands event types. + + This background task handles all model responses and converts + them to existing StreamEvent format for integration with Strands. + + Args: + session: BidirectionalConnection containing model session. + """ + log_flow("model_events", "processor started") + try: + async for provider_event in session.model_session.receive_events(): + if not session.active: + break + + # Convert provider events to Strands format + strands_event = _convert_to_strands_event(provider_event) + + # Handle interruption detection (multiple patterns) + if strands_event.get("interruptionDetected"): + log_event("interruption_forwarded") + await _handle_interruption(session) + # Forward interruption event to agent for application-level handling + await session.agent._output_queue.put(strands_event) + continue + + # Check for text-based interruption (Nova Sonic pattern) + if strands_event.get("textOutput"): + text_content = strands_event["textOutput"].get("content", "") + if '{ "interrupted" : true }' in text_content: + log_event("text_interruption_detected") + await _handle_interruption(session) + # Still forward the text event + await session.agent._output_queue.put(strands_event) + continue + + # Queue tool requests for concurrent execution + if strands_event.get("toolUse"): + log_event("tool_queued", name=strands_event["toolUse"].get("name")) + await session.tool_queue.put(strands_event["toolUse"]) + continue + + # Send output events to Agent for receive() method + if strands_event.get("audioOutput") or strands_event.get("textOutput"): + await session.agent._output_queue.put(strands_event) + + # Update Agent conversation history using existing patterns + if strands_event.get("messageStop"): + log_event("message_added_to_history") + session.agent.messages.append(strands_event["messageStop"]["message"]) + + except Exception as e: + log_event("model_events_error", error=str(e)) + traceback.print_exc() + finally: + log_flow("model_events", "processor stopped") + + +async def _process_tool_execution(session: BidirectionalConnection) -> None: + """Execute tools concurrently using existing Strands infrastructure with barge-in support. + + This background task manages tool execution without blocking + model event processing or user interaction. Includes proper + task cleanup and cancellation handling. + + Args: + session: BidirectionalConnection containing tool queue. + """ + log_flow("tool_execution", "processor started") + while session.active: + try: + tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) + log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) + + if not session.active: + break + + task_id = str(uuid.uuid4()) + task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) + session.pending_tool_tasks[task_id] = task + + # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) + def cleanup_task(completed_task): + try: + # Remove from pending tasks + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + # Log completion status + if completed_task.cancelled(): + log_event("tool_task_cleanup_cancelled", task_id=task_id) + elif completed_task.exception(): + log_event("tool_task_cleanup_error", task_id=task_id, + error=str(completed_task.exception())) + else: + log_event("tool_task_cleanup_success", task_id=task_id) + except Exception as e: + log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) + + task.add_done_callback(cleanup_task) + + except asyncio.TimeoutError: + if not session.active: + break + # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS + completed_tasks = [ + task_id for task_id, task in session.pending_tool_tasks.items() + if task.done() + ] + for task_id in completed_tasks: + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + if completed_tasks: + log_event("periodic_task_cleanup", count=len(completed_tasks)) + + continue + except Exception as e: + log_event("tool_execution_error", error=str(e)) + if not session.active: + break + + log_flow("tool_execution", "processor stopped") + + +def _convert_to_strands_event(provider_event: Dict) -> Dict: + """Pass-through for events already normalized by provider sessions. + + Providers convert their raw events to standard format before reaching here. + This just validates and passes through the normalized events. + + Args: + provider_event: Already normalized event from provider session. + + Returns: + Dict: The same event, validated and passed through. + """ + # Basic validation - ensure we have a dict + if not isinstance(provider_event, dict): + return {} + + # Pass through - conversion already done by provider session + return provider_event + + +async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: + """Execute tool using existing Strands infrastructure with barge-in support. + + Model-agnostic tool execution that uses existing Strands tool system, + handles interruption during execution, and delegates result formatting + to provider-specific session. + + Args: + session: BidirectionalConnection for context. + tool_use: Tool use event to execute. + """ + tool_name = tool_use.get('name') + tool_id = tool_use.get('toolUseId') + + try: + # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) + if session.interrupted or not session.active: + log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) + return + + # Create message structure for existing tool system + tool_message: Message = { + "role": "assistant", + "content": [{"toolUse": tool_use}] + } + + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + + # Validate using existing Strands validation + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) + + # Filter valid tool uses + valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] + + if not valid_tool_uses: + log_event("tool_validation_failed", name=tool_name, id=tool_id) + return + + # Execute tools directly (simpler approach for bidirectional) + for tool_use in valid_tool_uses: + # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION + if session.interrupted or not session.active: + log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) + return + + tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) + + if tool_func: + try: + actual_func = _extract_callable_function(tool_func) + + # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK + # For async tools, we could wrap with asyncio.wait_for with cancellation + # For sync tools, we execute directly but check interruption after + result = actual_func(**tool_use.get("input", {})) + + # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION + if session.interrupted or not session.active: + log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) + return + + tool_result = _create_success_result(tool_use["toolUseId"], result) + tool_results.append(tool_result) + + except asyncio.CancelledError: + # Tool was cancelled due to interruption + log_event("tool_execution_cancelled", name=tool_name, id=tool_id) + return + except Exception as e: + # 🔥 CHECK FOR INTERRUPTION EVEN ON ERROR + if session.interrupted or not session.active: + log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) + return + + log_event("tool_execution_failed", name=tool_name, error=str(e)) + tool_result = _create_error_result(tool_use["toolUseId"], str(e)) + tool_results.append(tool_result) + else: + log_event("tool_not_found", name=tool_name) + + # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS + if session.interrupted or not session.active: + log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) + return + + # Send results through provider-specific session + for result in tool_results: + await session.model_session.send_tool_result( + tool_use.get("toolUseId"), + result + ) + + log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) + + except asyncio.CancelledError: + # Task was cancelled due to interruption - this is expected behavior + log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) + raise # Re-raise to properly handle cancellation + except Exception as e: + log_event("tool_execution_error", name=tool_use.get('name'), error=str(e)) + + # Only send error if not interrupted + if not session.interrupted and session.active: + try: + await session.model_session.send_tool_error( + tool_use.get("toolUseId"), + str(e) + ) + except Exception as send_error: + log_event("tool_error_send_failed", error=str(send_error)) + + +def _extract_callable_function(tool_func): + """Extract the callable function from different tool object types.""" + if hasattr(tool_func, '_tool_func'): + return tool_func._tool_func + elif hasattr(tool_func, 'func'): + return tool_func.func + elif callable(tool_func): + return tool_func + else: + raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") + + +def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: + """Create a successful tool result.""" + return { + "toolUseId": tool_use_id, + "status": "success", + "content": [{"text": json.dumps(result)}] + } + + +def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: + """Create an error tool result.""" + return { + "toolUseId": tool_use_id, + "status": "error", + "content": [{"text": f"Error: {error}"}] + } \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py new file mode 100644 index 000000000..b2b10a5f2 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -0,0 +1,2 @@ +"""Bidirectional streaming models package.""" +# Models package \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py new file mode 100644 index 000000000..32727105d --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -0,0 +1,115 @@ +"""Bidirectional model interface for real-time streaming conversations. + +INTERFACE PURPOSE: +----------------- +Declares bidirectional capabilities separate from existing Model hierarchy to maintain +clean separation of concerns. Models choose to implement this interface explicitly +for bidirectional streaming support. + +PROVIDER ABSTRACTION: +-------------------- +Abstracts incompatible initialization patterns: Nova Sonic's event-driven sequences, +Google's WebSocket setup, OpenAI's dual protocol support. Normalizes different tool +calling approaches and handles provider-specific session management with varying +time limits and connection patterns. + +SESSION-BASED APPROACH: +---------------------- +Unlike existing Model interface's stateless request-response pattern where each +stream() call processes complete messages independently, BidirectionalModel introduces +session-based approach where create_bidirectional_connection() establishes persistent +connections supporting real-time bidirectional communication during active generation. +""" + +import abc +import logging +from typing import Any, AsyncIterable, Dict, List, Optional + +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.bidirectional_streaming import AudioInputEvent + +logger = logging.getLogger(__name__) + +class BidirectionalModelSession(abc.ABC): + """Model-specific session interface for bidirectional communication.""" + + @abc.abstractmethod + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive events from model in provider-agnostic format. + + Normalizes different provider event formats so the event loop + can process all providers uniformly. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to model during session. + + Manages complex audio encoding and provider-specific event sequences + while presenting simple AudioInputEvent interface to Agent. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content processed concurrently with ongoing generation. + + Enables natural interruption and follow-up questions without session restart. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_interrupt(self) -> None: + """Send interruption signal to immediately stop generation. + + Critical for responsive conversational experiences where users + can naturally interrupt mid-response. + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool execution result to model in provider-specific format. + + Each provider handles result formatting according to their protocol: + - Nova Sonic: toolResult events with JSON content + - Google Live API: toolResponse with specific structure + - OpenAI Realtime: function call responses with call_id correlation + """ + raise NotImplementedError + + @abc.abstractmethod + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool execution error to model in provider-specific format.""" + raise NotImplementedError + + @abc.abstractmethod + async def close(self) -> None: + """Close session and cleanup resources with graceful termination.""" + raise NotImplementedError + + +class BidirectionalModel(abc.ABC): + """Interface for models that support bidirectional streaming. + + Separate from Model to maintain clean separation of concerns. + Models choose to implement this interface explicitly. + """ + + @abc.abstractmethod + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create bidirectional session with model-specific implementation. + + Abstracts complex provider-specific initialization while presenting + uniform interface to Agent. + """ + raise NotImplementedError + diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py new file mode 100644 index 000000000..ba71cd4d3 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -0,0 +1,777 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +PROVIDER PURPOSE: +---------------- +Implements BidirectionalModel and BidirectionalModelSession interfaces for Nova Sonic, +handling the complex three-tier event management and structured event cleanup sequences +required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. + +NOVA SONIC SPECIFICS: +-------------------- +- Requires hierarchical event sequences: sessionStart → promptStart → content streaming +- Uses hex-encoded base64 audio format that needs conversion to raw bytes +- Implements toolUse/toolResult with content containers and identifier tracking +- Manages 8-minute session limits with proper cleanup sequences +- Handles stopReason: "INTERRUPTED" events for interruption detection + +INTEGRATION APPROACH: +-------------------- +Adapts existing Nova Sonic sample patterns to work with Strands bidirectional +infrastructure while maintaining provider-specific protocol requirements. +""" + +import asyncio +import base64 +import json +import logging +import time +import traceback +import uuid +from typing import Any, AsyncIterable, Dict, List, Optional + +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk +from smithy_aws_core.credentials_resolvers.environment import EnvironmentCredentialsResolver + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + InterruptionDetectedEvent, + TextOutputEvent, +) +from ..utils.debug import log_event, log_flow, time_it_async +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = { + "maxTokens": 1024, + "topP": 0.9, + "temperature": 0.7 +} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64" +} + +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 24000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH" +} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + +# Timing constants +SILENCE_THRESHOLD = 2.0 +EVENT_DELAY = 0.1 +RESPONSE_TIMEOUT = 1.0 + + +class NovaSonicSession(BidirectionalModelSession): + """Nova Sonic session handling protocol-specific details.""" + + def __init__(self, stream, config: Dict[str, Any]): + """Initialize Nova Sonic session. + + Args: + stream: Nova Sonic bidirectional stream. + config: Model configuration. + """ + self.stream = stream + self.config = config + self.prompt_name = str(uuid.uuid4()) + self._active = True + + # Nova Sonic requires unique content names + self.audio_content_name = str(uuid.uuid4()) + self.text_content_name = str(uuid.uuid4()) + + # Audio session state + self.audio_session_active = False + self.last_audio_time = None + self.silence_threshold = SILENCE_THRESHOLD + self.silence_task = None + + # Validate stream + if not stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + + logger.debug("Nova Sonic session initialized with prompt: %s", self.prompt_name) + + async def initialize( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None + ) -> None: + """Initialize Nova Sonic session with required protocol sequence.""" + try: + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." + + init_events = self._build_initialization_events(system_prompt, tools or [], messages) + + log_flow("nova_init", f"sending {len(init_events)} events") + await self._send_initialization_events(init_events) + + log_event("nova_session_initialized") + self._response_task = asyncio.create_task(self._process_responses()) + + except Exception as e: + logger.error("Error during Nova Sonic initialization: %s", e) + raise + + def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec], + messages: Optional[Messages]) -> List[str]: + """Build the sequence of initialization events.""" + events = [ + self._get_session_start_event(), + self._get_prompt_start_event(tools) + ] + + events.extend(self._get_system_prompt_events(system_prompt)) + + # Message history would be processed here if needed in the future + # Currently not implemented as it's not used in the existing test cases + + return events + + async def _send_initialization_events(self, events: List[str]) -> None: + """Send initialization events with required delays.""" + for i, event in enumerate(events): + await time_it_async(f"send_init_event_{i+1}", lambda: self._send_nova_event(event)) + await asyncio.sleep(EVENT_DELAY) + + async def _process_responses(self) -> None: + """Process Nova Sonic responses continuously.""" + log_flow("nova_responses", "processor started") + + try: + while self._active: + try: + output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) + result = await output[1].receive() + + if result.value and result.value.bytes_: + await self._handle_response_data(result.value.bytes_.decode('utf-8')) + + except asyncio.TimeoutError: + await asyncio.sleep(0.1) + continue + except Exception as e: + log_event("nova_response_error", error=str(e)) + await asyncio.sleep(0.1) + continue + + except Exception as e: + log_event("nova_fatal_error", error=str(e)) + finally: + log_flow("nova_responses", "processor stopped") + + async def _handle_response_data(self, response_data: str) -> None: + """Handle decoded response data from Nova Sonic.""" + try: + json_data = json.loads(response_data) + + if 'event' in json_data: + nova_event = json_data['event'] + self._log_event_type(nova_event) + + if not hasattr(self, '_event_queue'): + self._event_queue = asyncio.Queue() + + await self._event_queue.put(nova_event) + except json.JSONDecodeError as e: + log_event("nova_json_error", error=str(e)) + + def _log_event_type(self, nova_event: Dict[str, Any]) -> None: + """Log specific Nova Sonic event types for debugging.""" + if 'usageEvent' in nova_event: + log_event("nova_usage", usage=nova_event['usageEvent']) + elif 'textOutput' in nova_event: + log_event("nova_text_output") + elif 'toolUse' in nova_event: + tool_use = nova_event['toolUse'] + log_event("nova_tool_use", name=tool_use['toolName'], id=tool_use['toolUseId']) + elif 'audioOutput' in nova_event: + audio_content = nova_event['audioOutput']['content'] + audio_bytes = base64.b64decode(audio_content) + log_event("nova_audio_output", bytes=len(audio_bytes)) + + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive Nova Sonic events and convert to provider-agnostic format.""" + if not self.stream: + logger.error("Stream is None") + return + + log_flow("nova_events", "starting event stream") + + # Emit session start event to Strands event system + session_start: BidirectionalConnectionStartEvent = { + "sessionId": self.prompt_name, + "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} + } + yield { + "BidirectionalConnectionStart": session_start + } + + # Initialize event queue if not already done + if not hasattr(self, '_event_queue'): + self._event_queue = asyncio.Queue() + + try: + while self._active: + try: + # Get events from the queue populated by _process_responses + nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + + # Convert to provider-agnostic format + provider_event = self._convert_nova_event(nova_event) + if provider_event: + yield provider_event + + except asyncio.TimeoutError: + # No events in queue - continue waiting + continue + + except Exception as e: + logger.error("Error receiving Nova Sonic event: %s", e) + logger.error(traceback.format_exc()) + finally: + # Emit session end event when exiting + session_end: BidirectionalConnectionEndEvent = { + "sessionId": self.prompt_name, + "reason": "session_complete", + "metadata": {"provider": "nova_sonic"} + } + yield { + "BidirectionalConnectionEnd": session_end + } + + async def start_audio_session(self) -> None: + """Start audio input session (call once before sending audio chunks).""" + if self.audio_session_active: + return + + log_event("nova_audio_session_start") + + audio_content_start = json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG + } + } + }) + + await self._send_nova_event(audio_content_start) + self.audio_session_active = True + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio using Nova Sonic protocol-specific format.""" + if not self._active: + return + + # Start audio session if not already active + if not self.audio_session_active: + await self.start_audio_session() + + # Update last audio time and cancel any pending silence task + self.last_audio_time = time.time() + if self.silence_task and not self.silence_task.done(): + self.silence_task.cancel() + + # Convert audio to Nova Sonic base64 format + nova_audio_data = base64.b64encode(audio_input["audioData"]).decode('utf-8') + + # Send audio input event + audio_event = json.dumps({ + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "content": nova_audio_data + } + } + }) + + await self._send_nova_event(audio_event) + + # Start silence detection task + self.silence_task = asyncio.create_task(self._check_silence()) + + async def _check_silence(self): + """Check for silence and automatically end audio session.""" + try: + await asyncio.sleep(self.silence_threshold) + if self.audio_session_active and self.last_audio_time: + elapsed = time.time() - self.last_audio_time + if elapsed >= self.silence_threshold: + log_event("nova_silence_detected", elapsed=elapsed) + await self.end_audio_input() + except asyncio.CancelledError: + pass + + async def end_audio_input(self) -> None: + """End current audio input session to trigger Nova Sonic processing.""" + if not self.audio_session_active: + return + + log_event("nova_audio_session_end") + + audio_content_end = json.dumps({ + "event": { + "contentEnd": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name + } + } + }) + + await self._send_nova_event(audio_content_end) + self.audio_session_active = False + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content using Nova Sonic format.""" + if not self._active: + return + + content_name = str(uuid.uuid4()) + events = [ + self._get_text_content_start_event(content_name), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name) + ] + + for event in events: + await self._send_nova_event(event) + + async def send_interrupt(self) -> None: + """Send interruption signal to Nova Sonic.""" + if not self._active: + return + + # Nova Sonic handles interruption through special input events + interrupt_event = { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "stopReason": "INTERRUPTED" + } + } + } + await self._send_nova_event(interrupt_event) + + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool result using Nova Sonic toolResult format.""" + if not self._active: + return + + log_event("nova_tool_result_send", id=tool_use_id) + content_name = str(uuid.uuid4()) + events = [ + self._get_tool_content_start_event(content_name, tool_use_id), + self._get_tool_result_event(content_name, result), + self._get_content_end_event(content_name) + ] + + for i, event in enumerate(events): + await time_it_async(f"send_tool_event_{i+1}", lambda: self._send_nova_event(event)) + + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool error using Nova Sonic format.""" + log_event("nova_tool_error_send", id=tool_use_id, error=error) + error_result = {"error": error} + await self.send_tool_result(tool_use_id, error_result) + + async def close(self) -> None: + """Close Nova Sonic session with proper cleanup sequence.""" + if not self._active: + return + + log_flow("nova_cleanup", "starting session close") + self._active = False + + # Cancel response processing task if running + if hasattr(self, '_response_task') and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + # End audio session if active + if self.audio_session_active: + await self.end_audio_input() + + # Send cleanup events + cleanup_events = [ + self._get_prompt_end_event(), + self._get_session_end_event() + ] + + for event in cleanup_events: + try: + await self._send_nova_event(event) + except Exception as e: + logger.warning("Error during Nova Sonic cleanup: %s", e) + + # Close stream + try: + await self.stream.input_stream.close() + except Exception as e: + logger.warning("Error closing Nova Sonic stream: %s", e) + + except Exception as e: + log_event("nova_cleanup_error", error=str(e)) + finally: + log_event("nova_session_closed") + + def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Convert Nova Sonic events to provider-agnostic format.""" + # Handle audio output + if "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + + audio_output: AudioOutputEvent = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": "base64" + } + + return { + "audioOutput": audio_output + } + + # Handle text output + elif "textOutput" in nova_event: + text_content = nova_event["textOutput"]["content"] + # Use stored role from contentStart event, fallback to event role + role = getattr(self, '_current_role', nova_event["textOutput"].get("role", "assistant")) + + # Check for Nova Sonic interruption pattern (matches working sample) + if '{ "interrupted" : true }' in text_content: + log_event("nova_interruption_in_text") + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + return { + "interruptionDetected": interruption + } + + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag + if role == "USER": + print(f"User: {text_content}") + elif role == "ASSISTANT": + print(f"Assistant: {text_content}") + + text_output: TextOutputEvent = { + "text": text_content, + "role": role.lower() + } + + return { + "textOutput": text_output + } + + # Handle tool use + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]) + } + + return { + "toolUse": tool_use_event + } + + # Handle interruption + elif nova_event.get("stopReason") == "INTERRUPTED": + log_event("nova_interruption_stop_reason") + + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + + return { + "interruptionDetected": interruption + } + + # Handle usage events (ignore) + elif "usageEvent" in nova_event: + return None + + # Handle content start events (track role) + elif "contentStart" in nova_event: + role = nova_event["contentStart"].get("role", "unknown") + # Store role for subsequent text output events + self._current_role = role + return None + + # Handle other events + else: + return None + + # Nova Sonic event template methods + def _get_session_start_event(self) -> str: + """Generate Nova Sonic session start event.""" + return json.dumps({ + "event": { + "sessionStart": { + "inferenceConfiguration": NOVA_INFERENCE_CONFIG + } + } + }) + + def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + prompt_start_event = { + "event": { + "promptStart": { + "promptName": self.prompt_name, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: + """Build tool configuration from tool specs.""" + tool_config = [] + for tool in tools: + input_schema = ({"json": json.dumps(tool['inputSchema']['json'])} + if 'json' in tool['inputSchema'] + else {"json": json.dumps(tool['inputSchema'])}) + + tool_config.append({ + "toolSpec": { + "name": tool["name"], + "description": tool["description"], + "inputSchema": input_schema + } + }) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str) -> List[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt), + self._get_content_end_event(content_name) + ] + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG + } + } + }) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: + """Generate tool content start event.""" + return json.dumps({ + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG + } + } + } + }) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps({ + "event": { + "textInput": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": text + } + } + }) + + def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: + """Generate tool result event.""" + return json.dumps({ + "event": { + "toolResult": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": json.dumps(result) + } + } + }) + + def _get_content_end_event(self, content_name: str) -> str: + """Generate content end event.""" + return json.dumps({ + "event": { + "contentEnd": { + "promptName": self.prompt_name, + "contentName": content_name + } + } + }) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({ + "event": { + "promptEnd": { + "promptName": self.prompt_name + } + } + }) + + def _get_session_end_event(self) -> str: + """Generate session end event.""" + return json.dumps({ + "event": { + "sessionEnd": {} + } + }) + + async def _send_nova_event(self, event: str) -> None: + """Send event JSON string to Nova Sonic stream.""" + try: + + # Event is already a JSON string + bytes_data = event.encode('utf-8') + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self.stream.input_stream.send(chunk) + logger.debug("Successfully sent Nova Sonic event") + + except Exception as e: + logger.error("Error sending Nova Sonic event: %s", e) + logger.error("Event was: %s", event) + raise + + +class NovaSonicBidirectionalModel(BidirectionalModel): + """Nova Sonic model implementing bidirectional capabilities.""" + + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): + """Initialize Nova Sonic bidirectional model. + + Args: + model_id: Nova Sonic model identifier. + region: AWS region. + **config: Additional configuration. + """ + self.model_id = model_id + self.region = region + self.config = config + self._client = None + + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) + + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create Nova Sonic bidirectional session.""" + log_flow("nova_session_create", "starting") + + # Initialize client if needed + if not self._client: + await time_it_async("initialize_client", lambda: self._initialize_client()) + + # Start Nova Sonic bidirectional stream + try: + stream = await time_it_async("invoke_model_with_bidirectional_stream", + lambda: self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + )) + + # Create and initialize session + session = NovaSonicSession(stream, self.config) + await time_it_async("initialize_session", + lambda: session.initialize(system_prompt, tools, messages)) + + log_event("nova_session_created") + return session + except Exception as e: + log_event("nova_session_create_error", error=str(e)) + logger.error("Failed to create Nova Sonic session: %s", e) + raise + + async def _initialize_client(self) -> None: + """Initialize Nova Sonic client.""" + try: + + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + http_auth_scheme_resolver=HTTPAuthSchemeResolver(), + http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()} + ) + + self._client = BedrockRuntimeClient(config=config) + logger.debug("Nova Sonic client initialized") + + except ImportError as e: + logger.error("Nova Sonic dependencies not available: %s", e) + raise + except Exception as e: + logger.error("Error initializing Nova Sonic client: %s", e) + raise + diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py new file mode 100644 index 000000000..f35fd4462 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -0,0 +1,203 @@ +"""Simple bidirectional streaming test with enhanced interruption support.""" + +import asyncio +import time +import pyaudio + +from src.strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from src.strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands_tools import calculator + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for( + context["audio_out"].get(), + timeout=0.1 + ) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + # Handle audio output + if "audioOutput" in event: + if not context.get("interrupted", False): + context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + + # Handle interruption events + elif "interruptionDetected" in event: + context["interrupted"] = True + elif "interrupted" in event: + context["interrupted"] = True + + # Handle text output with interruption detection + elif "textOutput" in event: + text_content = event["textOutput"].get("content", "") + role = event["textOutput"].get("role", "unknown") + + # Check for text-based interruption patterns + if '{ "interrupted" : true }' in text_content: + context["interrupted"] = True + elif "interrupted" in text_content.lower(): + context["interrupted"] = True + + # Log text output + if role.upper() == "USER": + print(f"User: {text_content}") + elif role.upper() == "ASSISTANT": + print(f"Assistant: {text_content}") + + except asyncio.CancelledError: + pass + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 16000 + } + await agent.send_audio(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) # Restored to working timing + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for bidirectional streaming test.""" + print("Starting bidirectional streaming test...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + + # Initialize model and agent + model = NovaSonicBidirectionalModel(region="us-east-1") + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant." + ) + + await agent.start_conversation() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "session": agent._session, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + } + + print("Speak into microphone. Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.end_conversation() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py new file mode 100644 index 000000000..f6441d2f0 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -0,0 +1,3 @@ +"""Bidirectional streaming types package.""" +# Types package + diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py new file mode 100644 index 000000000..2b1480e62 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -0,0 +1,167 @@ +"""Bidirectional streaming types for real-time audio/text conversations. + +PROBLEM ADDRESSED: +----------------- +Strands currently uses a request-response architecture without bidirectional streaming +support. Users cannot interrupt ongoing responses, provide additional context during +processing, or engage in real-time conversations. Each interaction requires a complete +request-response cycle. + +ARCHITECTURAL TRANSFORMATION: +---------------------------- +Current Limitations: Strands' unidirectional architecture follows sequential +request-response cycles that prevent real-time interaction. This represents a +pull-based architecture where the model receives the request, processes it, and +sends a response back. + +Bidirectional Solution: Uses persistent session-based connections with continuous +input and output flow. This implements a push-based architecture where the model +sends updates to the client as soon as response becomes available, without explicit +client requests. + +KEY CHARACTERISTICS: +------------------- +- Persistent Sessions: Connections remain open for extended periods (Nova Sonic: 8 minutes, + Google Live API: 15 minutes, OpenAI Realtime: 30 minutes) maintaining conversation context +- Bidirectional Communication: Users can send input while models generate responses +- Interruption Handling: Users can interrupt ongoing model responses in real-time without + terminating the session +- Tool Execution: Tools execute concurrently within the conversation flow rather than + requiring requests rebuilding + +PROVIDER NORMALIZATION: +---------------------- +Must normalize incompatible audio formats: Nova Sonic's hex-encoded base64, Google's +LINEAR16 PCM, OpenAI's Base64-encoded PCM16. Requires unified interruption event types +to handle Nova Sonic's stopReason = INTERRUPTED events, Google's VAD cancellation, and +OpenAI's conversation.item.truncate. + +This module extends existing StreamEvent types while maintaining backward compatibility +with existing Strands streaming patterns. +""" + +from typing import Any, Dict, Literal, Optional + +from strands.types.content import Role +from strands.types.streaming import StreamEvent +from typing_extensions import TypedDict + +# Audio format constants +SUPPORTED_AUDIO_FORMATS = ['pcm', 'wav', 'opus', 'mp3'] +SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] +SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo +DEFAULT_SAMPLE_RATE = 16000 +DEFAULT_CHANNELS = 1 + +class AudioOutputEvent(TypedDict): + """Audio output event from the model. + + Standardizes audio output across different providers using raw bytes + instead of provider-specific encodings (base64, hex, etc.). + + Attributes: + audioData: Raw audio bytes (not base64 or hex encoded). + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. + encoding: Original provider encoding for debugging purposes. + """ + + audioData: bytes + format: Literal['pcm', 'wav', 'opus', 'mp3'] + sampleRate: Literal[16000, 24000, 48000] + channels: Literal[1, 2] + encoding: Optional[str] + + +class AudioInputEvent(TypedDict): + """Audio input event for sending audio to the model. + + Used when sending audio data through send_audio() method. + + Attributes: + audioData: Raw audio bytes to send to model. + format: Audio format from SUPPORTED_AUDIO_FORMATS. + sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. + channels: Channel count from SUPPORTED_CHANNELS. + """ + + audioData: bytes + format: Literal['pcm', 'wav', 'opus', 'mp3'] + sampleRate: Literal[16000, 24000, 48000] + channels: Literal[1, 2] + + +class TextOutputEvent(TypedDict): + """Text output event from the model during bidirectional streaming. + + Attributes: + text: The text content from the model. + role: The role of the message sender. + """ + + text: str + role: Role + + +class InterruptionDetectedEvent(TypedDict): + """Interruption detection event. + + Signals when user interruption is detected during model generation. + + Attributes: + reason: Interruption reason from predefined set. + """ + + reason: Literal['user_input', 'vad_detected', 'manual'] + + +class BidirectionalConnectionStartEvent(TypedDict, total=False): + """Session start event for bidirectional streaming. + + Attributes: + sessionId: Unique session identifier. + metadata: Provider-specific session metadata. + """ + + sessionId: Optional[str] + metadata: Optional[Dict[str, Any]] + + +class BidirectionalConnectionEndEvent(TypedDict): + """Session end event for bidirectional streaming. + + Attributes: + reason: Reason for session end from predefined set. + sessionId: Unique session identifier. + metadata: Provider-specific session metadata. + """ + + reason: Literal['user_request', 'timeout', 'error'] + sessionId: Optional[str] + metadata: Optional[Dict[str, Any]] + + +class BidirectionalStreamEvent(StreamEvent, total=False): + """Bidirectional stream event extending existing StreamEvent. + + Inherits all existing StreamEvent fields (contentBlockDelta, toolUse, + messageStart, etc.) while adding bidirectional-specific events. + Maintains full backward compatibility with existing Strands streaming. + + Attributes: + audioOutput: Audio output from the model. + audioInput: Audio input sent to the model. + textOutput: Text output from the model. + interruptionDetected: User interruption detection. + BidirectionalConnectionStart: Session start event. + BidirectionalConnectionEnd: Session end event. + """ + + audioOutput: AudioOutputEvent + audioInput: AudioInputEvent + textOutput: TextOutputEvent + interruptionDetected: InterruptionDetectedEvent + BidirectionalConnectionStart: BidirectionalConnectionStartEvent + BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py new file mode 100644 index 000000000..1e88b6ead --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/utils/debug.py @@ -0,0 +1,45 @@ +"""Debug utilities for Strands bidirectional streaming. + +Provides consistent debug logging across all bidirectional streaming components +with configurable output control matching the Nova Sonic tool use example. +""" + +import datetime +import inspect +import time + +# Debug logging system matching successful tool use example +DEBUG = False # Disable debug logging for clean output like tool use example + +def debug_print(message): + """Print debug message with timestamp and function name.""" + if DEBUG: + function_name = inspect.stack()[1].function + if function_name == 'time_it_async': + function_name = inspect.stack()[2].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + print(f"{timestamp} {function_name} {message}") + +def log_event(event_type, **context): + """Log important events with structured context.""" + if DEBUG: + function_name = inspect.stack()[1].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" + print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") + +def log_flow(step, details=""): + """Log important flow steps without excessive detail.""" + if DEBUG: + function_name = inspect.stack()[1].function + timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + print(f"{timestamp} {function_name} FLOW: {step} {details}") + +async def time_it_async(label, method_to_run): + """Time asynchronous method execution.""" + start_time = time.perf_counter() + result = await method_to_run() + end_time = time.perf_counter() + debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") + return result + From 9165a2074eaa3a35f1e7df01ddfdd04c7d6e523a Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 30 Sep 2025 10:41:16 -0400 Subject: [PATCH 02/23] Updated doc strings, updated method from send_text() and send_audio() to send(), Updated imports --- pyproject.toml | 2 +- .../bidirectional_streaming/agent/agent.py | 105 +++++++------ .../event_loop/bidirectional_event_loop.py | 62 ++++---- .../models/bidirectional_model.py | 75 +++++----- .../models/novasonic.py | 141 +++++++++--------- .../tests/test_bidirectional_streaming.py | 26 +++- .../types/bidirectional_streaming.py | 86 ++++------- 7 files changed, 234 insertions(+), 263 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d4f7e6eee..dd01ebde3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index cfc005576..023997551 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -1,30 +1,22 @@ """Bidirectional Agent for real-time streaming conversations. -AGENT PURPOSE: -------------- -Provides type-safe constructor and session management for real-time audio/text -interaction. Serves as the bidirectional equivalent to invoke_async() → stream_async() -but establishes sessions that continue indefinitely with concurrent task management. +Provides real-time audio and text interaction through persistent streaming sessions. +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive +continuous responses including audio output. -ARCHITECTURAL APPROACH: ----------------------- -While invoke_async() creates single request-response cycles that terminate after -stop_reason: "end_turn" with sequential tool processing, start_conversation() -establishes persistent sessions with concurrent processing of model events, tool -execution, and user input without session termination. - -DESIGN CHOICE: -------------- -Uses dedicated BidirectionalAgent class (Option 1 from design document) for: -- Type safety with no conditional behavior based on model type -- Separation of concerns - solely focused on bidirectional streaming -- Future proofing - allows changes without implications to existing Agent class +Key capabilities: +- Persistent conversation sessions with concurrent processing +- Real-time audio input/output streaming +- Mid-conversation interruption and tool execution +- Event-driven communication with model providers """ import asyncio import logging -from typing import AsyncIterable, List, Optional +from typing import AsyncIterable, List, Optional, Union +from strands.tools.executors import ConcurrentToolExecutor from strands.tools.registry import ToolRegistry from strands.types.content import Messages @@ -39,8 +31,8 @@ class BidirectionalAgent: """Agent for bidirectional streaming conversations. - Provides type-safe constructor and session management for real-time - audio/text interaction with concurrent processing capabilities. + Enables real-time audio and text interaction with AI models through persistent + sessions. Supports concurrent tool execution and interruption handling. """ def __init__( @@ -69,60 +61,63 @@ def __init__( self.tool_registry.initialize_tools() # Initialize tool executor for concurrent execution - from strands.tools.executors import ConcurrentToolExecutor self.tool_executor = ConcurrentToolExecutor() # Session management self._session = None self._output_queue = asyncio.Queue() - async def start_conversation(self) -> None: - """Initialize persistent bidirectional session for real-time interaction. + async def start(self) -> None: + """Start a persistent bidirectional conversation session. - Creates provider-specific session and starts concurrent background tasks - for model events, tool execution, and session lifecycle management. + Initializes the streaming session and starts background tasks for processing + model events, tool execution, and session management. Raises: ValueError: If conversation already active. ConnectionError: If session creation fails. """ if self._session and self._session.active: - raise ValueError("Conversation already active. Call end_conversation() first.") + raise ValueError("Conversation already active. Call end() first.") log_flow("conversation_start", "initializing session") self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - async def send_text(self, text: str) -> None: - """Send text input during active session without interrupting model generation. + async def send(self, input_data: Union[str, AudioInputEvent]) -> None: + """Send input to the model (text or audio). - Args: - text: Text message to send to the model. - - Raises: - ValueError: If no active session. - """ - self._validate_active_session() - log_event("text_sent", length=len(text)) - await self._session.model_session.send_text_content(text) - - async def send_audio(self, audio_input: AudioInputEvent) -> None: - """Send audio input during active session for real-time speech interaction. + Unified method for sending both text and audio input to the model during + an active conversation session. Args: - audio_input: AudioInputEvent containing audio data and configuration. + input_data: Either a string for text input or AudioInputEvent for audio input. Raises: - ValueError: If no active session. + ValueError: If no active session or invalid input type. """ self._validate_active_session() - await self._session.model_session.send_audio_content(audio_input) + + if isinstance(input_data, str): + # Handle text input + log_event("text_sent", length=len(input_data)) + await self._session.model_session.send_text_content(input_data) + elif isinstance(input_data, dict) and "audioData" in input_data: + # Handle audio input (AudioInputEvent) + await self._session.model_session.send_audio_content(input_data) + else: + raise ValueError( + "Input must be either a string (text) or AudioInputEvent " + "(dict with audioData, format, sampleRate, channels)" + ) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive output events from the model including audio, text. + """Receive events from the model including audio, text, and tool calls. - Provides access to model output events processed by background tasks. - Events include audio output, text responses, tool calls, and session updates. + Yields model output events processed by background tasks including audio output, + text responses, tool calls, and session updates. Yields: BidirectionalStreamEvent: Events from the model session. @@ -135,10 +130,10 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: continue async def interrupt(self) -> None: - """Interrupt current model generation and switch to listening mode. + """Interrupt the current model generation and clear audio buffers. - Sends interruption signal to immediately stop generation and clear - pending audio output for responsive conversational experience. + Sends interruption signal to stop generation immediately and clears + pending audio output for responsive conversation flow. Raises: ValueError: If no active session. @@ -146,11 +141,11 @@ async def interrupt(self) -> None: self._validate_active_session() await self._session.model_session.send_interrupt() - async def end_conversation(self) -> None: - """End session and cleanup resources including background tasks. + async def end(self) -> None: + """End the conversation session and cleanup all resources. - Performs graceful session termination with proper resource cleanup - including background task cancellation and connection closure. + Terminates the streaming session, cancels background tasks, and + closes the connection to the model provider. """ if self._session: await stop_bidirectional_connection(self._session) @@ -163,5 +158,5 @@ def _validate_active_session(self) -> None: ValueError: If no active session. """ if not self._session or not self._session.active: - raise ValueError("No active conversation. Call start_conversation() first.") + raise ValueError("No active conversation. Call start() first.") diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 2164115d8..3884750d5 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -1,16 +1,14 @@ """Bidirectional session management for concurrent streaming conversations. -SESSION PURPOSE: ---------------- -Session wrapper for bidirectional communication that manages concurrent tasks for -model events, tool execution, and audio processing while providing simple interface -for Agent interaction. +Manages bidirectional communication sessions with concurrent processing of model events, +tool execution, and audio processing. Provides coordination between background tasks +while maintaining a simple interface for agent interaction. -CONCURRENT ARCHITECTURE: ------------------------ -Unlike existing event_loop_cycle() that processes events sequentially where tool -execution blocks conversation, this module coordinates concurrent tasks through -asyncio queues and background task management. +Features: +- Concurrent task management for model events and tool execution +- Interruption handling with audio buffer clearing +- Tool execution with cancellation support +- Session lifecycle management """ import asyncio @@ -35,10 +33,10 @@ class BidirectionalConnection: - """Session wrapper for bidirectional communication. + """Session wrapper for bidirectional communication with concurrent task management. - Manages concurrent tasks for model events, tool execution, and audio processing - while providing simple interface for Agent interaction. + Coordinates background tasks for model event processing, tool execution, and audio + handling while providing a simple interface for agent interactions. """ def __init__(self, model_session: BidirectionalModelSession, agent): @@ -66,8 +64,8 @@ def __init__(self, model_session: BidirectionalModelSession, agent): async def start_bidirectional_connection(agent) -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. - Creates provider-specific session and starts concurrent tasks for model events, - tool execution, and session lifecycle management. + Creates a model-specific session and starts background tasks for processing + model events, executing tools, and managing the session lifecycle. Args: agent: BidirectionalAgent instance. @@ -147,11 +145,10 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: - """Main bidirectional event loop coordinator - runs continuously during session. + """Main event loop coordinator that runs continuously during the session. - Coordinates background tasks and manages session lifecycle. Unlike the - sequential event_loop_cycle() that processes events one by one, this coordinator - manages concurrent tasks and session state. + Monitors background tasks, manages session state, and handles session lifecycle. + Provides supervision for concurrent model event processing and tool execution. Args: session: BidirectionalConnection to coordinate. @@ -185,10 +182,10 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No async def _handle_interruption(session: BidirectionalConnection) -> None: - """Handle interruption detection with comprehensive task cancellation. + """Handle interruption detection with task cancellation and audio buffer clearing. - Sets interruption flag, cancels pending tool tasks, and aggressively - clears audio output queue following Nova Sonic example patterns. + Cancels pending tool tasks and clears audio output queues to ensure responsive + interruption handling during conversations. Args: session: BidirectionalConnection to handle interruption for. @@ -251,10 +248,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async def _process_model_events(session: BidirectionalConnection) -> None: - """Process model events using existing Strands event types. + """Process model events and convert them to Strands format. - This background task handles all model responses and converts - them to existing StreamEvent format for integration with Strands. + Background task that handles all model responses, converts provider-specific + events to standardized formats, and manages interruption detection. Args: session: BidirectionalConnection containing model session. @@ -309,11 +306,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async def _process_tool_execution(session: BidirectionalConnection) -> None: - """Execute tools concurrently using existing Strands infrastructure with barge-in support. + """Execute tools concurrently with interruption support. - This background task manages tool execution without blocking - model event processing or user interaction. Includes proper - task cleanup and cancellation handling. + Background task that manages tool execution without blocking model event + processing or user interaction. Includes proper task cleanup and cancellation + handling for interruptions. Args: session: BidirectionalConnection containing tool queue. @@ -396,11 +393,10 @@ def _convert_to_strands_event(provider_event: Dict) -> Dict: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: - """Execute tool using existing Strands infrastructure with barge-in support. + """Execute tool using Strands infrastructure with interruption support. - Model-agnostic tool execution that uses existing Strands tool system, - handles interruption during execution, and delegates result formatting - to provider-specific session. + Executes tools using the existing Strands tool system, handles interruption + during execution, and sends results back to the model provider. Args: session: BidirectionalConnection for context. diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 32727105d..81e5cd9d6 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -1,24 +1,14 @@ """Bidirectional model interface for real-time streaming conversations. -INTERFACE PURPOSE: ------------------ -Declares bidirectional capabilities separate from existing Model hierarchy to maintain -clean separation of concerns. Models choose to implement this interface explicitly -for bidirectional streaming support. +Defines the interface for models that support bidirectional streaming capabilities. +Provides abstractions for different model providers with connection-based communication +patterns that support real-time audio and text interaction. -PROVIDER ABSTRACTION: --------------------- -Abstracts incompatible initialization patterns: Nova Sonic's event-driven sequences, -Google's WebSocket setup, OpenAI's dual protocol support. Normalizes different tool -calling approaches and handles provider-specific session management with varying -time limits and connection patterns. - -SESSION-BASED APPROACH: ----------------------- -Unlike existing Model interface's stateless request-response pattern where each -stream() call processes complete messages independently, BidirectionalModel introduces -session-based approach where create_bidirectional_connection() establishes persistent -connections supporting real-time bidirectional communication during active generation. +Features: +- connection-based persistent connections +- Real-time bidirectional communication +- Provider-agnostic event normalization +- Tool execution integration """ import abc @@ -32,51 +22,54 @@ logger = logging.getLogger(__name__) class BidirectionalModelSession(abc.ABC): - """Model-specific session interface for bidirectional communication.""" + """Abstract interface for model-specific bidirectional communication connections. + + Defines the contract for managing persistent streaming connections with individual + model providers, handling audio/text input, receiving events, and managing + tool execution results. + """ @abc.abstractmethod async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: - """Receive events from model in provider-agnostic format. + """Receive events from the model in standardized format. - Normalizes different provider event formats so the event loop - can process all providers uniformly. + Converts provider-specific events to a common format that can be + processed uniformly by the event loop. """ raise NotImplementedError @abc.abstractmethod async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to model during session. + """Send audio content to the model during an active connection. - Manages complex audio encoding and provider-specific event sequences - while presenting simple AudioInputEvent interface to Agent. + Handles audio encoding and provider-specific formatting while presenting + a simple AudioInputEvent interface. """ raise NotImplementedError @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content processed concurrently with ongoing generation. + """Send text content to the model during ongoing generation. - Enables natural interruption and follow-up questions without session restart. + Allows natural interruption and follow-up questions without requiring + connection restart. """ raise NotImplementedError @abc.abstractmethod async def send_interrupt(self) -> None: - """Send interruption signal to immediately stop generation. + """Send interruption signal to stop generation immediately. - Critical for responsive conversational experiences where users - can naturally interrupt mid-response. + Enables responsive conversational experiences where users can + naturally interrupt during model responses. """ raise NotImplementedError @abc.abstractmethod async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: - """Send tool execution result to model in provider-specific format. + """Send tool execution result to the model. - Each provider handles result formatting according to their protocol: - - Nova Sonic: toolResult events with JSON content - - Google Live API: toolResponse with specific structure - - OpenAI Realtime: function call responses with call_id correlation + Formats and sends tool results according to the provider's specific protocol. """ raise NotImplementedError @@ -87,15 +80,15 @@ async def send_tool_error(self, tool_use_id: str, error: str) -> None: @abc.abstractmethod async def close(self) -> None: - """Close session and cleanup resources with graceful termination.""" + """Close the connection and cleanup resources.""" raise NotImplementedError class BidirectionalModel(abc.ABC): """Interface for models that support bidirectional streaming. - Separate from Model to maintain clean separation of concerns. - Models choose to implement this interface explicitly. + Defines the contract for creating persistent streaming connections that support + real-time audio and text communication with AI models. """ @abc.abstractmethod @@ -106,10 +99,10 @@ async def create_bidirectional_connection( messages: Optional[Messages] = None, **kwargs ) -> BidirectionalModelSession: - """Create bidirectional session with model-specific implementation. + """Create a bidirectional connection with the model. - Abstracts complex provider-specific initialization while presenting - uniform interface to Agent. + Establishes a persistent connection for real-time communication while + abstracting provider-specific initialization requirements. """ raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index ba71cd4d3..4332181b5 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,23 +1,15 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. -PROVIDER PURPOSE: ----------------- -Implements BidirectionalModel and BidirectionalModelSession interfaces for Nova Sonic, -handling the complex three-tier event management and structured event cleanup sequences -required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. +Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol. -NOVA SONIC SPECIFICS: --------------------- -- Requires hierarchical event sequences: sessionStart → promptStart → content streaming -- Uses hex-encoded base64 audio format that needs conversion to raw bytes -- Implements toolUse/toolResult with content containers and identifier tracking -- Manages 8-minute session limits with proper cleanup sequences -- Handles stopReason: "INTERRUPTED" events for interruption detection - -INTEGRATION APPROACH: --------------------- -Adapts existing Nova Sonic sample patterns to work with Strands bidirectional -infrastructure while maintaining provider-specific protocol requirements. +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events """ import asyncio @@ -85,10 +77,15 @@ class NovaSonicSession(BidirectionalModelSession): - """Nova Sonic session handling protocol-specific details.""" + """Nova Sonic connection implementation handling the provider's specific protocol. + + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidirectionalModelSession + interface. + """ def __init__(self, stream, config: Dict[str, Any]): - """Initialize Nova Sonic session. + """Initialize Nova Sonic connection. Args: stream: Nova Sonic bidirectional stream. @@ -103,8 +100,8 @@ def __init__(self, stream, config: Dict[str, Any]): self.audio_content_name = str(uuid.uuid4()) self.text_content_name = str(uuid.uuid4()) - # Audio session state - self.audio_session_active = False + # Audio connection state + self.audio_connection_active = False self.last_audio_time = None self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None @@ -114,7 +111,7 @@ def __init__(self, stream, config: Dict[str, Any]): logger.error("Stream is None") raise ValueError("Stream cannot be None") - logger.debug("Nova Sonic session initialized with prompt: %s", self.prompt_name) + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) async def initialize( self, @@ -122,7 +119,7 @@ async def initialize( tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None ) -> None: - """Initialize Nova Sonic session with required protocol sequence.""" + """Initialize Nova Sonic connection with required protocol sequence.""" try: system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." @@ -131,7 +128,7 @@ async def initialize( log_flow("nova_init", f"sending {len(init_events)} events") await self._send_initialization_events(init_events) - log_event("nova_session_initialized") + log_event("nova_connection_initialized") self._response_task = asyncio.create_task(self._process_responses()) except Exception as e: @@ -142,7 +139,7 @@ def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec] messages: Optional[Messages]) -> List[str]: """Build the sequence of initialization events.""" events = [ - self._get_session_start_event(), + self._get_connection_start_event(), self._get_prompt_start_event(tools) ] @@ -223,13 +220,13 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: log_flow("nova_events", "starting event stream") - # Emit session start event to Strands event system - session_start: BidirectionalConnectionStartEvent = { - "sessionId": self.prompt_name, + # Emit connection start event to Strands event system + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.prompt_name, "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} } yield { - "BidirectionalConnectionStart": session_start + "BidirectionalConnectionStart": connection_start } # Initialize event queue if not already done @@ -255,22 +252,22 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) finally: - # Emit session end event when exiting - session_end: BidirectionalConnectionEndEvent = { - "sessionId": self.prompt_name, - "reason": "session_complete", + # Emit connection end event when exiting + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.prompt_name, + "reason": "connection_complete", "metadata": {"provider": "nova_sonic"} } yield { - "BidirectionalConnectionEnd": session_end + "BidirectionalConnectionEnd": connection_end } - async def start_audio_session(self) -> None: - """Start audio input session (call once before sending audio chunks).""" - if self.audio_session_active: + async def start_audio_connection(self) -> None: + """Start audio input connection (call once before sending audio chunks).""" + if self.audio_connection_active: return - log_event("nova_audio_session_start") + log_event("nova_audio_connection_start") audio_content_start = json.dumps({ "event": { @@ -286,16 +283,16 @@ async def start_audio_session(self) -> None: }) await self._send_nova_event(audio_content_start) - self.audio_session_active = True + self.audio_connection_active = True async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio using Nova Sonic protocol-specific format.""" if not self._active: return - # Start audio session if not already active - if not self.audio_session_active: - await self.start_audio_session() + # Start audio connection if not already active + if not self.audio_connection_active: + await self.start_audio_connection() # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() @@ -322,10 +319,10 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: self.silence_task = asyncio.create_task(self._check_silence()) async def _check_silence(self): - """Check for silence and automatically end audio session.""" + """Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) - if self.audio_session_active and self.last_audio_time: + if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: log_event("nova_silence_detected", elapsed=elapsed) @@ -334,11 +331,11 @@ async def _check_silence(self): pass async def end_audio_input(self) -> None: - """End current audio input session to trigger Nova Sonic processing.""" - if not self.audio_session_active: + """End current audio input connection to trigger Nova Sonic processing.""" + if not self.audio_connection_active: return - log_event("nova_audio_session_end") + log_event("nova_audio_connection_end") audio_content_end = json.dumps({ "event": { @@ -350,7 +347,7 @@ async def end_audio_input(self) -> None: }) await self._send_nova_event(audio_content_end) - self.audio_session_active = False + self.audio_connection_active = False async def send_text_content(self, text: str, **kwargs) -> None: """Send text content using Nova Sonic format.""" @@ -407,11 +404,11 @@ async def send_tool_error(self, tool_use_id: str, error: str) -> None: await self.send_tool_result(tool_use_id, error_result) async def close(self) -> None: - """Close Nova Sonic session with proper cleanup sequence.""" + """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - log_flow("nova_cleanup", "starting session close") + log_flow("nova_cleanup", "starting connection close") self._active = False # Cancel response processing task if running @@ -423,14 +420,14 @@ async def close(self) -> None: pass try: - # End audio session if active - if self.audio_session_active: + # End audio connection if active + if self.audio_connection_active: await self.end_audio_input() # Send cleanup events cleanup_events = [ self._get_prompt_end_event(), - self._get_session_end_event() + self._get_connection_end_event() ] for event in cleanup_events: @@ -448,7 +445,7 @@ async def close(self) -> None: except Exception as e: log_event("nova_cleanup_error", error=str(e)) finally: - log_event("nova_session_closed") + log_event("nova_connection_closed") def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Convert Nova Sonic events to provider-agnostic format.""" @@ -542,8 +539,8 @@ def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, return None # Nova Sonic event template methods - def _get_session_start_event(self) -> str: - """Generate Nova Sonic session start event.""" + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" return json.dumps({ "event": { "sessionStart": { @@ -676,11 +673,11 @@ def _get_prompt_end_event(self) -> str: } }) - def _get_session_end_event(self) -> str: - """Generate session end event.""" + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" return json.dumps({ "event": { - "sessionEnd": {} + "connectionEnd": {} } }) @@ -703,7 +700,11 @@ async def _send_nova_event(self, event: str) -> None: class NovaSonicBidirectionalModel(BidirectionalModel): - """Nova Sonic model implementing bidirectional capabilities.""" + """Nova Sonic model implementation for bidirectional streaming. + + Provides access to Amazon's Nova Sonic model through the bidirectional + streaming interface, handling AWS authentication and connection management. + """ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): """Initialize Nova Sonic bidirectional model. @@ -727,8 +728,8 @@ async def create_bidirectional_connection( messages: Optional[Messages] = None, **kwargs ) -> BidirectionalModelSession: - """Create Nova Sonic bidirectional session.""" - log_flow("nova_session_create", "starting") + """Create Nova Sonic bidirectional connection.""" + log_flow("nova_connection_create", "starting") # Initialize client if needed if not self._client: @@ -741,16 +742,16 @@ async def create_bidirectional_connection( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) )) - # Create and initialize session - session = NovaSonicSession(stream, self.config) - await time_it_async("initialize_session", - lambda: session.initialize(system_prompt, tools, messages)) + # Create and initialize connection + connection = NovaSonicSession(stream, self.config) + await time_it_async("initialize_connection", + lambda: connection.initialize(system_prompt, tools, messages)) - log_event("nova_session_created") - return session + log_event("nova_connection_created") + return connection except Exception as e: - log_event("nova_session_create_error", error=str(e)) - logger.error("Failed to create Nova Sonic session: %s", e) + log_event("nova_connection_create_error", error=str(e)) + logger.error("Failed to create Nova Sonic connection: %s", e) raise async def _initialize_client(self) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index f35fd4462..d650aba9b 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -1,11 +1,20 @@ -"""Simple bidirectional streaming test with enhanced interruption support.""" +"""Test suite for bidirectional streaming with real-time audio interaction. + +Tests the complete bidirectional streaming system including audio input/output, +interruption handling, and concurrent tool execution using Nova Sonic. +""" import asyncio +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import time import pyaudio -from src.strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from src.strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel from strands_tools import calculator @@ -139,9 +148,10 @@ async def send(agent, context): audio_event = { "audioData": audio_bytes, "format": "pcm", - "sampleRate": 16000 + "sampleRate": 16000, + "channels": 1 } - await agent.send_audio(audio_event) + await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing except asyncio.CancelledError: @@ -165,14 +175,14 @@ async def main(duration=180): system_prompt="You are a helpful assistant." ) - await agent.start_conversation() + await agent.start() # Create shared context for all tasks context = { "active": True, "audio_in": asyncio.Queue(), "audio_out": asyncio.Queue(), - "session": agent._session, + "connection": agent._session, "duration": duration, "start_time": time.time(), "interrupted": False, @@ -196,7 +206,7 @@ async def main(duration=180): finally: print("Cleaning up...") context["active"] = False - await agent.end_conversation() + await agent.end() if __name__ == "__main__": diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 2b1480e62..fabe53ac9 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -1,43 +1,20 @@ """Bidirectional streaming types for real-time audio/text conversations. -PROBLEM ADDRESSED: ------------------ -Strands currently uses a request-response architecture without bidirectional streaming -support. Users cannot interrupt ongoing responses, provide additional context during -processing, or engage in real-time conversations. Each interaction requires a complete -request-response cycle. - -ARCHITECTURAL TRANSFORMATION: ----------------------------- -Current Limitations: Strands' unidirectional architecture follows sequential -request-response cycles that prevent real-time interaction. This represents a -pull-based architecture where the model receives the request, processes it, and -sends a response back. - -Bidirectional Solution: Uses persistent session-based connections with continuous -input and output flow. This implements a push-based architecture where the model -sends updates to the client as soon as response becomes available, without explicit -client requests. - -KEY CHARACTERISTICS: -------------------- -- Persistent Sessions: Connections remain open for extended periods (Nova Sonic: 8 minutes, - Google Live API: 15 minutes, OpenAI Realtime: 30 minutes) maintaining conversation context -- Bidirectional Communication: Users can send input while models generate responses -- Interruption Handling: Users can interrupt ongoing model responses in real-time without - terminating the session -- Tool Execution: Tools execute concurrently within the conversation flow rather than - requiring requests rebuilding - -PROVIDER NORMALIZATION: ----------------------- -Must normalize incompatible audio formats: Nova Sonic's hex-encoded base64, Google's -LINEAR16 PCM, OpenAI's Base64-encoded PCM16. Requires unified interruption event types -to handle Nova Sonic's stopReason = INTERRUPTED events, Google's VAD cancellation, and -OpenAI's conversation.item.truncate. - -This module extends existing StreamEvent types while maintaining backward compatibility -with existing Strands streaming patterns. +Type definitions for bidirectional streaming that extends Strands' existing streaming +capabilities with real-time audio and persistent connection support. + +Key features: +- Audio input/output events with standardized formats +- Interruption detection and handling +- connection lifecycle management +- Provider-agnostic event types +- Backwards compatibility with existing StreamEvent types + +Audio format normalization: +- Supports PCM, WAV, Opus, and MP3 formats +- Standardizes sample rates (16kHz, 24kHz, 48kHz) +- Normalizes channel configurations (mono/stereo) +- Abstracts provider-specific encodings """ from typing import Any, Dict, Literal, Optional @@ -56,8 +33,8 @@ class AudioOutputEvent(TypedDict): """Audio output event from the model. - Standardizes audio output across different providers using raw bytes - instead of provider-specific encodings (base64, hex, etc.). + Provides standardized audio output format across different providers using + raw bytes instead of provider-specific encodings. Attributes: audioData: Raw audio bytes (not base64 or hex encoded). @@ -77,7 +54,7 @@ class AudioOutputEvent(TypedDict): class AudioInputEvent(TypedDict): """Audio input event for sending audio to the model. - Used when sending audio data through send_audio() method. + Used for sending audio data through the send() method. Attributes: audioData: Raw audio bytes to send to model. @@ -117,45 +94,44 @@ class InterruptionDetectedEvent(TypedDict): class BidirectionalConnectionStartEvent(TypedDict, total=False): - """Session start event for bidirectional streaming. + """connection start event for bidirectional streaming. Attributes: - sessionId: Unique session identifier. - metadata: Provider-specific session metadata. + connectionId: Unique connection identifier. + metadata: Provider-specific connection metadata. """ - sessionId: Optional[str] + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalConnectionEndEvent(TypedDict): - """Session end event for bidirectional streaming. + """connection end event for bidirectional streaming. Attributes: - reason: Reason for session end from predefined set. - sessionId: Unique session identifier. - metadata: Provider-specific session metadata. + reason: Reason for connection end from predefined set. + connectionId: Unique connection identifier. + metadata: Provider-specific connection metadata. """ reason: Literal['user_request', 'timeout', 'error'] - sessionId: Optional[str] + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. - Inherits all existing StreamEvent fields (contentBlockDelta, toolUse, - messageStart, etc.) while adding bidirectional-specific events. - Maintains full backward compatibility with existing Strands streaming. + Extends the existing StreamEvent type with bidirectional-specific events + while maintaining full backward compatibility with existing Strands streaming. Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. textOutput: Text output from the model. interruptionDetected: User interruption detection. - BidirectionalConnectionStart: Session start event. - BidirectionalConnectionEnd: Session end event. + BidirectionalConnectionStart: connection start event. + BidirectionalConnectionEnd: connection end event. """ audioOutput: AudioOutputEvent From 15df9f9c06748c06376b596c7186e3712192e3cd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 30 Sep 2025 10:45:29 -0400 Subject: [PATCH 03/23] Updated minimum python runtime dependency --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index dd01ebde3..f45794d12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", + "python>=3.12" ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ From 3a0e7d5c360107ea4a0c890bf1c9f18ee3f1c603 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 1 Oct 2025 23:54:05 -0400 Subject: [PATCH 04/23] fix imports --- .../bidirectional_streaming/__init__.py | 5 + .../bidirectional_streaming/agent/__init__.py | 7 +- .../bidirectional_streaming/agent/agent.py | 70 ++- .../event_loop/__init__.py | 17 +- .../event_loop/bidirectional_event_loop.py | 243 ++++---- .../models/__init__.py | 8 +- .../models/bidirectional_model.py | 38 +- .../models/novasonic.py | 546 ++++++++---------- .../tests/test_bidirectional_streaming.py | 65 +-- .../bidirectional_streaming/types/__init__.py | 32 +- .../types/bidirectional_streaming.py | 53 +- .../bidirectional_streaming/utils/__init__.py | 5 + .../bidirectional_streaming/utils/debug.py | 13 +- 13 files changed, 530 insertions(+), 572 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/__init__.py create mode 100644 src/strands/experimental/bidirectional_streaming/utils/__init__.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..f6a3b41bf --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1,5 @@ +"""Bidirectional streaming package for real-time audio/text conversations.""" + +from .utils import log_event, log_flow, time_it_async + +__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py index bbd2c91f3..c490e001d 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -1,2 +1,5 @@ -"""Bidirectional streaming agent package.""" -# Agent package \ No newline at end of file +"""Bidirectional agent for real-time streaming conversations.""" + +from .agent import BidirectionalAgent + +__all__ = ["BidirectionalAgent"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 023997551..d7a5f17a3 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -1,13 +1,13 @@ """Bidirectional Agent for real-time streaming conversations. Provides real-time audio and text interaction through persistent streaming sessions. -Unlike traditional request-response patterns, this agent maintains long-running -conversations where users can interrupt, provide additional input, and receive +Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive continuous responses including audio output. Key capabilities: - Persistent conversation sessions with concurrent processing -- Real-time audio input/output streaming +- Real-time audio input/output streaming - Mid-conversation interruption and tool execution - Event-driven communication with model providers """ @@ -16,10 +16,9 @@ import logging from typing import AsyncIterable, List, Optional, Union -from strands.tools.executors import ConcurrentToolExecutor -from strands.tools.registry import ToolRegistry -from strands.types.content import Messages - +from ....tools.executors import ConcurrentToolExecutor +from ....tools.registry import ToolRegistry +from ....types.content import Messages from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent @@ -30,20 +29,20 @@ class BidirectionalAgent: """Agent for bidirectional streaming conversations. - + Enables real-time audio and text interaction with AI models through persistent sessions. Supports concurrent tool execution and interruption handling. """ - + def __init__( self, model: BidirectionalModel, tools: Optional[List] = None, system_prompt: Optional[str] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, ): """Initialize bidirectional agent with required model and optional configuration. - + Args: model: BidirectionalModel instance supporting streaming sessions. tools: Optional list of tools available to the model. @@ -53,51 +52,51 @@ def __init__( self.model = model self.system_prompt = system_prompt self.messages = messages or [] - + # Initialize tool registry using existing Strands infrastructure self.tool_registry = ToolRegistry() if tools: self.tool_registry.process_tools(tools) self.tool_registry.initialize_tools() - + # Initialize tool executor for concurrent execution self.tool_executor = ConcurrentToolExecutor() - + # Session management self._session = None self._output_queue = asyncio.Queue() - + async def start(self) -> None: """Start a persistent bidirectional conversation session. - + Initializes the streaming session and starts background tasks for processing model events, tool execution, and session management. - + Raises: ValueError: If conversation already active. ConnectionError: If session creation fails. """ if self._session and self._session.active: raise ValueError("Conversation already active. Call end() first.") - + log_flow("conversation_start", "initializing session") self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - + async def send(self, input_data: Union[str, AudioInputEvent]) -> None: """Send input to the model (text or audio). - + Unified method for sending both text and audio input to the model during an active conversation session. - + Args: input_data: Either a string for text input or AudioInputEvent for audio input. - + Raises: ValueError: If no active session or invalid input type. """ self._validate_active_session() - + if isinstance(input_data, str): # Handle text input log_event("text_sent", length=len(input_data)) @@ -110,15 +109,13 @@ async def send(self, input_data: Union[str, AudioInputEvent]) -> None: "Input must be either a string (text) or AudioInputEvent " "(dict with audioData, format, sampleRate, channels)" ) - - async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model including audio, text, and tool calls. - + Yields model output events processed by background tasks including audio output, text responses, tool calls, and session updates. - + Yields: BidirectionalStreamEvent: Events from the model session. """ @@ -128,35 +125,34 @@ async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: yield event except asyncio.TimeoutError: continue - + async def interrupt(self) -> None: """Interrupt the current model generation and clear audio buffers. - - Sends interruption signal to stop generation immediately and clears + + Sends interruption signal to stop generation immediately and clears pending audio output for responsive conversation flow. - + Raises: ValueError: If no active session. """ self._validate_active_session() await self._session.model_session.send_interrupt() - + async def end(self) -> None: """End the conversation session and cleanup all resources. - - Terminates the streaming session, cancels background tasks, and + + Terminates the streaming session, cancels background tasks, and closes the connection to the model provider. """ if self._session: await stop_bidirectional_connection(self._session) self._session = None - + def _validate_active_session(self) -> None: """Validate that an active session exists. - + Raises: ValueError: If no active session. """ if not self._session or not self._session.active: raise ValueError("No active conversation. Call start() first.") - diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py index 24080b703..af8c4e1e1 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -1,2 +1,15 @@ -"""Bidirectional streaming event loop package.""" -# Event Loop package \ No newline at end of file +"""Event loop management for bidirectional streaming.""" + +from .bidirectional_event_loop import ( + BidirectionalConnection, + bidirectional_event_loop_cycle, + start_bidirectional_connection, + stop_bidirectional_connection, +) + +__all__ = [ + "BidirectionalConnection", + "start_bidirectional_connection", + "stop_bidirectional_connection", + "bidirectional_event_loop_cycle", +] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 3884750d5..c90d118ff 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -18,10 +18,9 @@ import uuid from typing import Any, Dict -from strands.tools._validator import validate_and_prepare_tools -from strands.types.content import Message -from strands.types.tools import ToolResult, ToolUse - +from ....tools._validator import validate_and_prepare_tools +from ....types.content import Message +from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -34,14 +33,14 @@ class BidirectionalConnection: """Session wrapper for bidirectional communication with concurrent task management. - + Coordinates background tasks for model event processing, tool execution, and audio handling while providing a simple interface for agent interactions. """ - + def __init__(self, model_session: BidirectionalModelSession, agent): """Initialize session with model session and agent reference. - + Args: model_session: Provider-specific bidirectional model session. agent: BidirectionalAgent instance for tool registry access. @@ -49,96 +48,93 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.model_session = model_session self.agent = agent self.active = True - + # Background processing coordination self.background_tasks = [] self.tool_queue = asyncio.Queue() self.audio_output_queue = asyncio.Queue() - + # Task management for cleanup self.pending_tool_tasks: Dict[str, asyncio.Task] = {} - + # Interruption handling (model-agnostic) self.interrupted = False -async def start_bidirectional_connection(agent) -> BidirectionalConnection: + +async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. - + Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. - + Args: agent: BidirectionalAgent instance. - + Returns: BidirectionalConnection: Active session with background tasks running. - """ + """ log_flow("session_start", "initializing model session") - + # Create provider-specific session model_session = await agent.model.create_bidirectional_connection( - system_prompt=agent.system_prompt, - tools=agent.tool_registry.get_all_tool_specs(), - messages=agent.messages + system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages ) - + # Create session wrapper for background processing session = BidirectionalConnection(model_session=model_session, agent=agent) - + # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization log_flow("background_tasks", "starting processors") session.background_tasks = [ - asyncio.create_task(_process_model_events(session)), # Handle model responses - asyncio.create_task(_process_tool_execution(session)) # Execute tools concurrently + asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently ] - + # Start main coordination cycle - session.main_cycle_task = asyncio.create_task( - bidirectional_event_loop_cycle(session) - ) - + session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) + # Give background tasks a moment to start await asyncio.sleep(0.1) log_event("session_ready", tasks=len(session.background_tasks)) - + return session async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: """End session and cleanup resources including background tasks. - + Args: session: BidirectionalConnection to cleanup. """ if not session.active: return - + log_flow("session_cleanup", "starting") session.active = False - + # Cancel pending tool tasks for _, task in session.pending_tool_tasks.items(): if not task.done(): task.cancel() - + # Cancel background tasks for task in session.background_tasks: if not task.done(): task.cancel() - + # Cancel main cycle task - if hasattr(session, 'main_cycle_task') and not session.main_cycle_task.done(): + if hasattr(session, "main_cycle_task") and not session.main_cycle_task.done(): session.main_cycle_task.cancel() - + # Wait for tasks to complete all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) - if hasattr(session, 'main_cycle_task'): + if hasattr(session, "main_cycle_task"): all_tasks.append(session.main_cycle_task) - + if all_tasks: await asyncio.gather(*all_tasks, return_exceptions=True) - + # Close model session await session.model_session.close() log_event("session_closed") @@ -146,10 +142,10 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: """Main event loop coordinator that runs continuously during the session. - + Monitors background tasks, manages session state, and handles session lifecycle. Provides supervision for concurrent model event processing and tool execution. - + Args: session: BidirectionalConnection to coordinate. """ @@ -160,7 +156,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No log_event("session_end", reason="all_processors_completed") session.active = False break - + # Check for failed background tasks for i, task in enumerate(session.background_tasks): if task.done() and not task.cancelled(): @@ -169,10 +165,10 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No log_event("session_error", processor=i, error=str(exception)) session.active = False raise exception - + # Brief pause before next supervision check await asyncio.sleep(SUPERVISION_INTERVAL) - + except asyncio.CancelledError: break except Exception as e: @@ -183,16 +179,16 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No async def _handle_interruption(session: BidirectionalConnection) -> None: """Handle interruption detection with task cancellation and audio buffer clearing. - + Cancels pending tool tasks and clears audio output queues to ensure responsive interruption handling during conversations. - + Args: session: BidirectionalConnection to handle interruption for. """ log_event("interruption_detected") session.interrupted = True - + # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) cancelled_tools = 0 for task_id, task in list(session.pending_tool_tasks.items()): @@ -200,10 +196,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: task.cancel() cancelled_tools += 1 log_event("tool_task_cancelled", task_id=task_id) - + if cancelled_tools > 0: log_event("tool_tasks_cancelled", count=cancelled_tools) - + # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) cleared_count = 0 while True: @@ -212,9 +208,9 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: cleared_count += 1 except asyncio.QueueEmpty: break - + # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, '_output_queue'): + if hasattr(session.agent, "_output_queue"): audio_cleared = 0 # Create a temporary list to hold non-audio events temp_events = [] @@ -228,20 +224,20 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: temp_events.append(event) except asyncio.QueueEmpty: pass - + # Put back non-audio events for event in temp_events: session.agent._output_queue.put_nowait(event) - + if audio_cleared > 0: log_event("agent_audio_queue_cleared", count=audio_cleared) - + if cleared_count > 0: log_event("session_audio_queue_cleared", count=cleared_count) - + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) await asyncio.sleep(0.05) - + # Reset interruption flag after clearing (automatic recovery) session.interrupted = False log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) @@ -249,10 +245,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async def _process_model_events(session: BidirectionalConnection) -> None: """Process model events and convert them to Strands format. - + Background task that handles all model responses, converts provider-specific events to standardized formats, and manages interruption detection. - + Args: session: BidirectionalConnection containing model session. """ @@ -261,10 +257,10 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async for provider_event in session.model_session.receive_events(): if not session.active: break - + # Convert provider events to Strands format strands_event = _convert_to_strands_event(provider_event) - + # Handle interruption detection (multiple patterns) if strands_event.get("interruptionDetected"): log_event("interruption_forwarded") @@ -272,7 +268,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Forward interruption event to agent for application-level handling await session.agent._output_queue.put(strands_event) continue - + # Check for text-based interruption (Nova Sonic pattern) if strands_event.get("textOutput"): text_content = strands_event["textOutput"].get("content", "") @@ -282,22 +278,22 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Still forward the text event await session.agent._output_queue.put(strands_event) continue - + # Queue tool requests for concurrent execution if strands_event.get("toolUse"): log_event("tool_queued", name=strands_event["toolUse"].get("name")) await session.tool_queue.put(strands_event["toolUse"]) continue - + # Send output events to Agent for receive() method if strands_event.get("audioOutput") or strands_event.get("textOutput"): await session.agent._output_queue.put(strands_event) - + # Update Agent conversation history using existing patterns if strands_event.get("messageStop"): log_event("message_added_to_history") session.agent.messages.append(strands_event["messageStop"]["message"]) - + except Exception as e: log_event("model_events_error", error=str(e)) traceback.print_exc() @@ -307,11 +303,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. - + Background task that manages tool execution without blocking model event processing or user interaction. Includes proper task cleanup and cancellation handling for interruptions. - + Args: session: BidirectionalConnection containing tool queue. """ @@ -320,143 +316,136 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) - + if not session.active: break - + task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task - + # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) def cleanup_task(completed_task): try: # Remove from pending tasks if task_id in session.pending_tool_tasks: del session.pending_tool_tasks[task_id] - + # Log completion status if completed_task.cancelled(): log_event("tool_task_cleanup_cancelled", task_id=task_id) elif completed_task.exception(): - log_event("tool_task_cleanup_error", task_id=task_id, - error=str(completed_task.exception())) + log_event("tool_task_cleanup_error", task_id=task_id, error=str(completed_task.exception())) else: log_event("tool_task_cleanup_success", task_id=task_id) except Exception as e: log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) - + task.add_done_callback(cleanup_task) - + except asyncio.TimeoutError: if not session.active: break # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS - completed_tasks = [ - task_id for task_id, task in session.pending_tool_tasks.items() - if task.done() - ] + completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] for task_id in completed_tasks: if task_id in session.pending_tool_tasks: del session.pending_tool_tasks[task_id] - + if completed_tasks: log_event("periodic_task_cleanup", count=len(completed_tasks)) - + continue except Exception as e: log_event("tool_execution_error", error=str(e)) if not session.active: break - + log_flow("tool_execution", "processor stopped") def _convert_to_strands_event(provider_event: Dict) -> Dict: """Pass-through for events already normalized by provider sessions. - + Providers convert their raw events to standard format before reaching here. This just validates and passes through the normalized events. - + Args: provider_event: Already normalized event from provider session. - + Returns: Dict: The same event, validated and passed through. """ # Basic validation - ensure we have a dict if not isinstance(provider_event, dict): return {} - + # Pass through - conversion already done by provider session return provider_event async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: """Execute tool using Strands infrastructure with interruption support. - + Executes tools using the existing Strands tool system, handles interruption during execution, and sends results back to the model provider. - + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. """ - tool_name = tool_use.get('name') - tool_id = tool_use.get('toolUseId') - + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + try: # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) if session.interrupted or not session.active: log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) return - + # Create message structure for existing tool system - tool_message: Message = { - "role": "assistant", - "content": [{"toolUse": tool_use}] - } - + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - + # Validate using existing Strands validation validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - + # Filter valid tool uses valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: log_event("tool_validation_failed", name=tool_name, id=tool_id) return - + # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION if session.interrupted or not session.active: log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) return - + tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) - + if tool_func: try: actual_func = _extract_callable_function(tool_func) - + # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK # For async tools, we could wrap with asyncio.wait_for with cancellation # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - + # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION if session.interrupted or not session.active: log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) return - + tool_result = _create_success_result(tool_use["toolUseId"], result) tool_results.append(tool_result) - + except asyncio.CancelledError: # Tool was cancelled due to interruption log_event("tool_execution_cancelled", name=tool_name, id=tool_id) @@ -466,50 +455,44 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: if session.interrupted or not session.active: log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) return - + log_event("tool_execution_failed", name=tool_name, error=str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: log_event("tool_not_found", name=tool_name) - + # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS if session.interrupted or not session.active: log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) return - + # Send results through provider-specific session for result in tool_results: - await session.model_session.send_tool_result( - tool_use.get("toolUseId"), - result - ) - + await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) + log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) - + except asyncio.CancelledError: # Task was cancelled due to interruption - this is expected behavior log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) raise # Re-raise to properly handle cancellation except Exception as e: - log_event("tool_execution_error", name=tool_use.get('name'), error=str(e)) - + log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) + # Only send error if not interrupted if not session.interrupted and session.active: try: - await session.model_session.send_tool_error( - tool_use.get("toolUseId"), - str(e) - ) + await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) except Exception as send_error: log_event("tool_error_send_failed", error=str(send_error)) def _extract_callable_function(tool_func): """Extract the callable function from different tool object types.""" - if hasattr(tool_func, '_tool_func'): + if hasattr(tool_func, "_tool_func"): return tool_func._tool_func - elif hasattr(tool_func, 'func'): + elif hasattr(tool_func, "func"): return tool_func.func elif callable(tool_func): return tool_func @@ -519,17 +502,9 @@ def _extract_callable_function(tool_func): def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: """Create a successful tool result.""" - return { - "toolUseId": tool_use_id, - "status": "success", - "content": [{"text": json.dumps(result)}] - } + return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: """Create an error tool result.""" - return { - "toolUseId": tool_use_id, - "status": "error", - "content": [{"text": f"Error: {error}"}] - } \ No newline at end of file + return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index b2b10a5f2..6cba974e0 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,2 +1,6 @@ -"""Bidirectional streaming models package.""" -# Models package \ No newline at end of file +"""Bidirectional model interfaces and implementations.""" + +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession + +__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 81e5cd9d6..cc803458b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -7,7 +7,7 @@ Features: - connection-based persistent connections - Real-time bidirectional communication -- Provider-agnostic event normalization +- Provider-agnostic event normalization - Tool execution integration """ @@ -21,63 +21,64 @@ logger = logging.getLogger(__name__) + class BidirectionalModelSession(abc.ABC): """Abstract interface for model-specific bidirectional communication connections. - + Defines the contract for managing persistent streaming connections with individual model providers, handling audio/text input, receiving events, and managing tool execution results. """ - + @abc.abstractmethod async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: """Receive events from the model in standardized format. - + Converts provider-specific events to a common format that can be processed uniformly by the event loop. """ raise NotImplementedError - + @abc.abstractmethod async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio content to the model during an active connection. - + Handles audio encoding and provider-specific formatting while presenting a simple AudioInputEvent interface. """ raise NotImplementedError - + @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: """Send text content to the model during ongoing generation. - + Allows natural interruption and follow-up questions without requiring connection restart. """ raise NotImplementedError - + @abc.abstractmethod async def send_interrupt(self) -> None: """Send interruption signal to stop generation immediately. - + Enables responsive conversational experiences where users can naturally interrupt during model responses. """ raise NotImplementedError - + @abc.abstractmethod async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: """Send tool execution result to the model. - + Formats and sends tool results according to the provider's specific protocol. """ raise NotImplementedError - + @abc.abstractmethod async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool execution error to model in provider-specific format.""" raise NotImplementedError - + @abc.abstractmethod async def close(self) -> None: """Close the connection and cleanup resources.""" @@ -86,23 +87,22 @@ async def close(self) -> None: class BidirectionalModel(abc.ABC): """Interface for models that support bidirectional streaming. - + Defines the contract for creating persistent streaming connections that support real-time audio and text communication with AI models. """ - + @abc.abstractmethod async def create_bidirectional_connection( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs + **kwargs, ) -> BidirectionalModelSession: """Create a bidirectional connection with the model. - + Establishes a persistent connection for real-time communication while abstracting provider-specific initialization requirements. """ raise NotImplementedError - diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 4332181b5..0efd2413c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -1,7 +1,7 @@ """Nova Sonic bidirectional model provider for real-time streaming conversations. Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the -complex event sequencing and audio processing required by Nova Sonic's +complex event sequencing and audio processing required by Nova Sonic's InvokeModelWithBidirectionalStream protocol. Nova Sonic specifics: @@ -42,11 +42,7 @@ logger = logging.getLogger(__name__) # Nova Sonic configuration constants -NOVA_INFERENCE_CONFIG = { - "maxTokens": 1024, - "topP": 0.9, - "temperature": 0.7 -} +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} NOVA_AUDIO_INPUT_CONFIG = { "mediaType": "audio/lpcm", @@ -54,7 +50,7 @@ "sampleSizeBits": 16, "channelCount": 1, "audioType": "SPEECH", - "encoding": "base64" + "encoding": "base64", } NOVA_AUDIO_OUTPUT_CONFIG = { @@ -64,7 +60,7 @@ "channelCount": 1, "voiceId": "matthew", "encoding": "base64", - "audioType": "SPEECH" + "audioType": "SPEECH", } NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} @@ -78,15 +74,15 @@ class NovaSonicSession(BidirectionalModelSession): """Nova Sonic connection implementation handling the provider's specific protocol. - + Manages Nova Sonic's complex event sequencing, audio format conversion, and tool execution patterns while providing the standard BidirectionalModelSession interface. """ - + def __init__(self, stream, config: Dict[str, Any]): """Initialize Nova Sonic connection. - + Args: stream: Nova Sonic bidirectional stream. config: Model configuration. @@ -95,80 +91,78 @@ def __init__(self, stream, config: Dict[str, Any]): self.config = config self.prompt_name = str(uuid.uuid4()) self._active = True - + # Nova Sonic requires unique content names self.audio_content_name = str(uuid.uuid4()) self.text_content_name = str(uuid.uuid4()) - + # Audio connection state self.audio_connection_active = False self.last_audio_time = None self.silence_threshold = SILENCE_THRESHOLD self.silence_task = None - + # Validate stream if not stream: logger.error("Stream is None") raise ValueError("Stream cannot be None") - + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) - + async def initialize( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None + messages: Optional[Messages] = None, ) -> None: """Initialize Nova Sonic connection with required protocol sequence.""" try: system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." - + init_events = self._build_initialization_events(system_prompt, tools or [], messages) - + log_flow("nova_init", f"sending {len(init_events)} events") await self._send_initialization_events(init_events) - + log_event("nova_connection_initialized") self._response_task = asyncio.create_task(self._process_responses()) - + except Exception as e: logger.error("Error during Nova Sonic initialization: %s", e) raise - - def _build_initialization_events(self, system_prompt: str, tools: List[ToolSpec], - messages: Optional[Messages]) -> List[str]: + + def _build_initialization_events( + self, system_prompt: str, tools: List[ToolSpec], messages: Optional[Messages] + ) -> List[str]: """Build the sequence of initialization events.""" - events = [ - self._get_connection_start_event(), - self._get_prompt_start_event(tools) - ] - + events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] + events.extend(self._get_system_prompt_events(system_prompt)) - + # Message history would be processed here if needed in the future # Currently not implemented as it's not used in the existing test cases - + return events - + async def _send_initialization_events(self, events: List[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i+1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_init_event_{i + 1}", lambda: self._send_nova_event(event)) await asyncio.sleep(EVENT_DELAY) - + async def _process_responses(self) -> None: """Process Nova Sonic responses continuously.""" log_flow("nova_responses", "processor started") - + try: while self._active: try: output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) result = await output[1].receive() - + if result.value and result.value.bytes_: - await self._handle_response_data(result.value.bytes_.decode('utf-8')) - + await self._handle_response_data(result.value.bytes_.decode("utf-8")) + except asyncio.TimeoutError: await asyncio.sleep(0.1) continue @@ -176,39 +170,39 @@ async def _process_responses(self) -> None: log_event("nova_response_error", error=str(e)) await asyncio.sleep(0.1) continue - + except Exception as e: log_event("nova_fatal_error", error=str(e)) finally: log_flow("nova_responses", "processor stopped") - + async def _handle_response_data(self, response_data: str) -> None: """Handle decoded response data from Nova Sonic.""" try: json_data = json.loads(response_data) - - if 'event' in json_data: - nova_event = json_data['event'] + + if "event" in json_data: + nova_event = json_data["event"] self._log_event_type(nova_event) - - if not hasattr(self, '_event_queue'): + + if not hasattr(self, "_event_queue"): self._event_queue = asyncio.Queue() - + await self._event_queue.put(nova_event) except json.JSONDecodeError as e: log_event("nova_json_error", error=str(e)) - + def _log_event_type(self, nova_event: Dict[str, Any]) -> None: """Log specific Nova Sonic event types for debugging.""" - if 'usageEvent' in nova_event: - log_event("nova_usage", usage=nova_event['usageEvent']) - elif 'textOutput' in nova_event: + if "usageEvent" in nova_event: + log_event("nova_usage", usage=nova_event["usageEvent"]) + elif "textOutput" in nova_event: log_event("nova_text_output") - elif 'toolUse' in nova_event: - tool_use = nova_event['toolUse'] - log_event("nova_tool_use", name=tool_use['toolName'], id=tool_use['toolUseId']) - elif 'audioOutput' in nova_event: - audio_content = nova_event['audioOutput']['content'] + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + log_event("nova_tool_use", name=tool_use["toolName"], id=tool_use["toolUseId"]) + elif "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) log_event("nova_audio_output", bytes=len(audio_bytes)) @@ -217,37 +211,35 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: if not self.stream: logger.error("Stream is None") return - + log_flow("nova_events", "starting event stream") - + # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { "connectionId": self.prompt_name, - "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")} - } - yield { - "BidirectionalConnectionStart": connection_start + "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")}, } - + yield {"BidirectionalConnectionStart": connection_start} + # Initialize event queue if not already done - if not hasattr(self, '_event_queue'): + if not hasattr(self, "_event_queue"): self._event_queue = asyncio.Queue() - + try: while self._active: try: # Get events from the queue populated by _process_responses nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - + # Convert to provider-agnostic format provider_event = self._convert_nova_event(nova_event) if provider_event: yield provider_event - + except asyncio.TimeoutError: # No events in queue - continue waiting continue - + except Exception as e: logger.error("Error receiving Nova Sonic event: %s", e) logger.error(traceback.format_exc()) @@ -256,68 +248,70 @@ async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: connection_end: BidirectionalConnectionEndEvent = { "connectionId": self.prompt_name, "reason": "connection_complete", - "metadata": {"provider": "nova_sonic"} + "metadata": {"provider": "nova_sonic"}, } - yield { - "BidirectionalConnectionEnd": connection_end - } - + yield {"BidirectionalConnectionEnd": connection_end} + async def start_audio_connection(self) -> None: """Start audio input connection (call once before sending audio chunks).""" if self.audio_connection_active: return - + log_event("nova_audio_connection_start") - - audio_content_start = json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "type": "AUDIO", - "interactive": True, - "role": "USER", - "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, + } } } - }) - + ) + await self._send_nova_event(audio_content_start) self.audio_connection_active = True - + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """Send audio using Nova Sonic protocol-specific format.""" if not self._active: return - + # Start audio connection if not already active if not self.audio_connection_active: await self.start_audio_connection() - + # Update last audio time and cancel any pending silence task self.last_audio_time = time.time() if self.silence_task and not self.silence_task.done(): self.silence_task.cancel() - + # Convert audio to Nova Sonic base64 format - nova_audio_data = base64.b64encode(audio_input["audioData"]).decode('utf-8') - + nova_audio_data = base64.b64encode(audio_input["audioData"]).decode("utf-8") + # Send audio input event - audio_event = json.dumps({ - "event": { - "audioInput": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name, - "content": nova_audio_data + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "content": nova_audio_data, + } } } - }) - + ) + await self._send_nova_event(audio_event) - + # Start silence detection task self.silence_task = asyncio.create_task(self._check_silence()) - + async def _check_silence(self): """Check for silence and automatically end audio connection.""" try: @@ -329,226 +323,195 @@ async def _check_silence(self): await self.end_audio_input() except asyncio.CancelledError: pass - + async def end_audio_input(self) -> None: """End current audio input connection to trigger Nova Sonic processing.""" if not self.audio_connection_active: return - + log_event("nova_audio_connection_end") - - audio_content_end = json.dumps({ - "event": { - "contentEnd": { - "promptName": self.prompt_name, - "contentName": self.audio_content_name - } - } - }) - + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} + ) + await self._send_nova_event(audio_content_end) self.audio_connection_active = False - + async def send_text_content(self, text: str, **kwargs) -> None: """Send text content using Nova Sonic format.""" if not self._active: return - + content_name = str(uuid.uuid4()) events = [ self._get_text_content_start_event(content_name), self._get_text_input_event(content_name, text), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + for event in events: await self._send_nova_event(event) - + async def send_interrupt(self) -> None: """Send interruption signal to Nova Sonic.""" if not self._active: return - + # Nova Sonic handles interruption through special input events interrupt_event = { "event": { "audioInput": { "promptName": self.prompt_name, "contentName": self.audio_content_name, - "stopReason": "INTERRUPTED" + "stopReason": "INTERRUPTED", } } } await self._send_nova_event(interrupt_event) - + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: """Send tool result using Nova Sonic toolResult format.""" if not self._active: return - + log_event("nova_tool_result_send", id=tool_use_id) content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), self._get_tool_result_event(content_name, result), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i+1}", lambda: self._send_nova_event(event)) - + await time_it_async(f"send_tool_event_{i + 1}", lambda: self._send_nova_event(event)) + async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool error using Nova Sonic format.""" log_event("nova_tool_error_send", id=tool_use_id, error=error) error_result = {"error": error} await self.send_tool_result(tool_use_id, error_result) - + async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - + log_flow("nova_cleanup", "starting connection close") self._active = False - + # Cancel response processing task if running - if hasattr(self, '_response_task') and not self._response_task.done(): + if hasattr(self, "_response_task") and not self._response_task.done(): self._response_task.cancel() try: await self._response_task except asyncio.CancelledError: pass - + try: # End audio connection if active if self.audio_connection_active: await self.end_audio_input() - + # Send cleanup events - cleanup_events = [ - self._get_prompt_end_event(), - self._get_connection_end_event() - ] - + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + for event in cleanup_events: try: await self._send_nova_event(event) except Exception as e: logger.warning("Error during Nova Sonic cleanup: %s", e) - + # Close stream try: await self.stream.input_stream.close() except Exception as e: logger.warning("Error closing Nova Sonic stream: %s", e) - + except Exception as e: log_event("nova_cleanup_error", error=str(e)) finally: log_event("nova_connection_closed") - + def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: """Convert Nova Sonic events to provider-agnostic format.""" # Handle audio output if "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) - + audio_output: AudioOutputEvent = { "audioData": audio_bytes, "format": "pcm", "sampleRate": 24000, "channels": 1, - "encoding": "base64" - } - - return { - "audioOutput": audio_output + "encoding": "base64", } - + + return {"audioOutput": audio_output} + # Handle text output elif "textOutput" in nova_event: text_content = nova_event["textOutput"]["content"] # Use stored role from contentStart event, fallback to event role - role = getattr(self, '_current_role', nova_event["textOutput"].get("role", "assistant")) - + role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) + # Check for Nova Sonic interruption pattern (matches working sample) if '{ "interrupted" : true }' in text_content: log_event("nova_interruption_in_text") - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - return { - "interruptionDetected": interruption - } - + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + return {"interruptionDetected": interruption} + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag if role == "USER": print(f"User: {text_content}") elif role == "ASSISTANT": print(f"Assistant: {text_content}") - - text_output: TextOutputEvent = { - "text": text_content, - "role": role.lower() - } - - return { - "textOutput": text_output - } - + + text_output: TextOutputEvent = {"text": text_content, "role": role.lower()} + + return {"textOutput": text_output} + # Handle tool use elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - + tool_use_event: ToolUse = { "toolUseId": tool_use["toolUseId"], "name": tool_use["toolName"], - "input": json.loads(tool_use["content"]) - } - - return { - "toolUse": tool_use_event + "input": json.loads(tool_use["content"]), } - + + return {"toolUse": tool_use_event} + # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": log_event("nova_interruption_stop_reason") - - interruption: InterruptionDetectedEvent = { - "reason": "user_input" - } - - return { - "interruptionDetected": interruption - } - + + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + + return {"interruptionDetected": interruption} + # Handle usage events (ignore) elif "usageEvent" in nova_event: return None - + # Handle content start events (track role) elif "contentStart" in nova_event: role = nova_event["contentStart"].get("role", "unknown") # Store role for subsequent text output events self._current_role = role return None - + # Handle other events else: return None - + # Nova Sonic event template methods def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" - return json.dumps({ - "event": { - "sessionStart": { - "inferenceConfiguration": NOVA_INFERENCE_CONFIG - } - } - }) - + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" prompt_start_event = { @@ -556,143 +519,121 @@ def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: "promptStart": { "promptName": self.prompt_name, "textOutputConfiguration": NOVA_TEXT_CONFIG, - "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, } } } - + if tools: tool_config = self._build_tool_configuration(tools) prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} - + return json.dumps(prompt_start_event) - + def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: """Build tool configuration from tool specs.""" tool_config = [] for tool in tools: - input_schema = ({"json": json.dumps(tool['inputSchema']['json'])} - if 'json' in tool['inputSchema'] - else {"json": json.dumps(tool['inputSchema'])}) - - tool_config.append({ - "toolSpec": { - "name": tool["name"], - "description": tool["description"], - "inputSchema": input_schema - } - }) + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} + ) return tool_config - + def _get_system_prompt_events(self, system_prompt: str) -> List[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ self._get_text_content_start_event(content_name, "SYSTEM"), self._get_text_input_event(content_name, system_prompt), - self._get_content_end_event(content_name) + self._get_content_end_event(content_name), ] - + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: """Generate text content start event.""" - return json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": content_name, - "type": "TEXT", - "role": role, - "interactive": True, - "textInputConfiguration": NOVA_TEXT_CONFIG + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } } } - }) - + ) + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: """Generate tool content start event.""" - return json.dumps({ - "event": { - "contentStart": { - "promptName": self.prompt_name, - "contentName": content_name, - "interactive": False, - "type": "TOOL", - "role": "TOOL", - "toolResultInputConfiguration": { - "toolUseId": tool_use_id, - "type": "TEXT", - "textInputConfiguration": NOVA_TEXT_CONFIG + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, } } } - }) - + ) + def _get_text_input_event(self, content_name: str, text: str) -> str: """Generate text input event.""" - return json.dumps({ - "event": { - "textInput": { - "promptName": self.prompt_name, - "contentName": content_name, - "content": text - } - } - }) - + return json.dumps( + {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} + ) + def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: """Generate tool result event.""" - return json.dumps({ - "event": { - "toolResult": { - "promptName": self.prompt_name, - "contentName": content_name, - "content": json.dumps(result) + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": json.dumps(result), + } } } - }) - + ) + def _get_content_end_event(self, content_name: str) -> str: """Generate content end event.""" - return json.dumps({ - "event": { - "contentEnd": { - "promptName": self.prompt_name, - "contentName": content_name - } - } - }) - + return json.dumps({"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": content_name}}}) + def _get_prompt_end_event(self) -> str: """Generate prompt end event.""" - return json.dumps({ - "event": { - "promptEnd": { - "promptName": self.prompt_name - } - } - }) - + return json.dumps({"event": {"promptEnd": {"promptName": self.prompt_name}}}) + def _get_connection_end_event(self) -> str: """Generate connection end event.""" - return json.dumps({ - "event": { - "connectionEnd": {} - } - }) - + return json.dumps({"event": {"connectionEnd": {}}}) + async def _send_nova_event(self, event: str) -> None: """Send event JSON string to Nova Sonic stream.""" try: - # Event is already a JSON string - bytes_data = event.encode('utf-8') - chunk = InvokeModelWithBidirectionalStreamInputChunk( - value=BidirectionalInputPayloadPart(bytes_=bytes_data) - ) + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) await self.stream.input_stream.send(chunk) logger.debug("Successfully sent Nova Sonic event") - + except Exception as e: logger.error("Error sending Nova Sonic event: %s", e) logger.error("Event was: %s", event) @@ -701,14 +642,14 @@ async def _send_nova_event(self, event: str) -> None: class NovaSonicBidirectionalModel(BidirectionalModel): """Nova Sonic model implementation for bidirectional streaming. - + Provides access to Amazon's Nova Sonic model through the bidirectional streaming interface, handling AWS authentication and connection management. """ - + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): """Initialize Nova Sonic bidirectional model. - + Args: model_id: Nova Sonic model identifier. region: AWS region. @@ -718,61 +659,60 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e self.region = region self.config = config self._client = None - + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) - + async def create_bidirectional_connection( self, system_prompt: Optional[str] = None, tools: Optional[List[ToolSpec]] = None, messages: Optional[Messages] = None, - **kwargs + **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" log_flow("nova_connection_create", "starting") - + # Initialize client if needed if not self._client: await time_it_async("initialize_client", lambda: self._initialize_client()) - + # Start Nova Sonic bidirectional stream try: - stream = await time_it_async("invoke_model_with_bidirectional_stream", + stream = await time_it_async( + "invoke_model_with_bidirectional_stream", lambda: self._client.invoke_model_with_bidirectional_stream( InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - )) - + ), + ) + # Create and initialize connection connection = NovaSonicSession(stream, self.config) - await time_it_async("initialize_connection", - lambda: connection.initialize(system_prompt, tools, messages)) - + await time_it_async("initialize_connection", lambda: connection.initialize(system_prompt, tools, messages)) + log_event("nova_connection_created") return connection except Exception as e: log_event("nova_connection_create_error", error=str(e)) logger.error("Failed to create Nova Sonic connection: %s", e) raise - + async def _initialize_client(self) -> None: """Initialize Nova Sonic client.""" try: - config = Config( endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), http_auth_scheme_resolver=HTTPAuthSchemeResolver(), - http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()} + http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, ) - + self._client = BedrockRuntimeClient(config=config) logger.debug("Nova Sonic client initialized") - + except ImportError as e: logger.error("Nova Sonic dependencies not available: %s", e) raise except Exception as e: logger.error("Error initializing Nova Sonic client: %s", e) raise - diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index d650aba9b..6ef96f919 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -11,12 +11,13 @@ # Add the src directory to Python path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) import time -import pyaudio -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +import pyaudio from strands_tools import calculator +from ..agent.agent import BidirectionalAgent +from ..models.novasonic import NovaSonicBidirectionalModel + async def play(context): """Play audio output with responsive interruption support.""" @@ -26,7 +27,7 @@ async def play(context): format=pyaudio.paInt16, output=True, rate=24000, - frames_per_buffer=1024, + frames_per_buffer=1024, ) try: @@ -40,36 +41,33 @@ async def play(context): context["audio_out"].get_nowait() except asyncio.QueueEmpty: break - + context["interrupted"] = False - await asyncio.sleep(0.05) + await asyncio.sleep(0.05) continue - + # Get next audio data - audio_data = await asyncio.wait_for( - context["audio_out"].get(), - timeout=0.1 - ) - + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + if audio_data and context["active"]: - chunk_size = 1024 + chunk_size = 1024 for i in range(0, len(audio_data), chunk_size): # Check for interruption before each chunk if context.get("interrupted", False) or not context["active"]: break - + end = min(i + chunk_size, len(audio_data)) chunk = audio_data[i:end] speaker.write(chunk) await asyncio.sleep(0.001) - + except asyncio.TimeoutError: continue # No audio available except asyncio.QueueEmpty: await asyncio.sleep(0.01) except asyncio.CancelledError: break - + except asyncio.CancelledError: pass finally: @@ -111,30 +109,30 @@ async def receive(agent, context): if "audioOutput" in event: if not context.get("interrupted", False): context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) - + # Handle interruption events elif "interruptionDetected" in event: context["interrupted"] = True elif "interrupted" in event: context["interrupted"] = True - + # Handle text output with interruption detection elif "textOutput" in event: text_content = event["textOutput"].get("content", "") role = event["textOutput"].get("role", "unknown") - + # Check for text-based interruption patterns if '{ "interrupted" : true }' in text_content: context["interrupted"] = True elif "interrupted" in text_content.lower(): context["interrupted"] = True - + # Log text output if role.upper() == "USER": print(f"User: {text_content}") elif role.upper() == "ASSISTANT": print(f"Assistant: {text_content}") - + except asyncio.CancelledError: pass @@ -145,18 +143,13 @@ async def send(agent, context): while time.time() - context["start_time"] < context["duration"]: try: audio_bytes = context["audio_in"].get_nowait() - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 16000, - "channels": 1 - } + audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} await agent.send(audio_event) except asyncio.QueueEmpty: await asyncio.sleep(0.01) # Restored to working timing except asyncio.CancelledError: break - + context["active"] = False except asyncio.CancelledError: pass @@ -166,14 +159,10 @@ async def main(duration=180): """Main function for bidirectional streaming test.""" print("Starting bidirectional streaming test...") print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") - + # Initialize model and agent model = NovaSonicBidirectionalModel(region="us-east-1") - agent = BidirectionalAgent( - model=model, - tools=[calculator], - system_prompt="You are a helpful assistant." - ) + agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") await agent.start() @@ -189,15 +178,11 @@ async def main(duration=180): } print("Speak into microphone. Press Ctrl+C to exit.") - + try: # Run all tasks concurrently await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True + play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True ) except KeyboardInterrupt: print("\nInterrupted by user") diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index f6441d2f0..510285f06 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -1,3 +1,31 @@ -"""Bidirectional streaming types package.""" -# Types package +"""Type definitions for bidirectional streaming.""" +from .bidirectional_streaming import ( + DEFAULT_CHANNELS, + DEFAULT_SAMPLE_RATE, + SUPPORTED_AUDIO_FORMATS, + SUPPORTED_CHANNELS, + SUPPORTED_SAMPLE_RATES, + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, +) + +__all__ = [ + "AudioInputEvent", + "AudioOutputEvent", + "BidirectionalConnectionEndEvent", + "BidirectionalConnectionStartEvent", + "BidirectionalStreamEvent", + "InterruptionDetectedEvent", + "TextOutputEvent", + "SUPPORTED_AUDIO_FORMATS", + "SUPPORTED_SAMPLE_RATES", + "SUPPORTED_CHANNELS", + "DEFAULT_SAMPLE_RATE", + "DEFAULT_CHANNELS", +] diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index fabe53ac9..01d72356a 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -19,23 +19,25 @@ from typing import Any, Dict, Literal, Optional -from strands.types.content import Role -from strands.types.streaming import StreamEvent from typing_extensions import TypedDict +from ....types.content import Role +from ....types.streaming import StreamEvent + # Audio format constants -SUPPORTED_AUDIO_FORMATS = ['pcm', 'wav', 'opus', 'mp3'] +SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo DEFAULT_SAMPLE_RATE = 16000 DEFAULT_CHANNELS = 1 + class AudioOutputEvent(TypedDict): """Audio output event from the model. - + Provides standardized audio output format across different providers using raw bytes instead of provider-specific encodings. - + Attributes: audioData: Raw audio bytes (not base64 or hex encoded). format: Audio format from SUPPORTED_AUDIO_FORMATS. @@ -43,9 +45,9 @@ class AudioOutputEvent(TypedDict): channels: Channel count from SUPPORTED_CHANNELS. encoding: Original provider encoding for debugging purposes. """ - + audioData: bytes - format: Literal['pcm', 'wav', 'opus', 'mp3'] + format: Literal["pcm", "wav", "opus", "mp3"] sampleRate: Literal[16000, 24000, 48000] channels: Literal[1, 2] encoding: Optional[str] @@ -53,78 +55,78 @@ class AudioOutputEvent(TypedDict): class AudioInputEvent(TypedDict): """Audio input event for sending audio to the model. - + Used for sending audio data through the send() method. - + Attributes: audioData: Raw audio bytes to send to model. format: Audio format from SUPPORTED_AUDIO_FORMATS. sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES. channels: Channel count from SUPPORTED_CHANNELS. """ - + audioData: bytes - format: Literal['pcm', 'wav', 'opus', 'mp3'] + format: Literal["pcm", "wav", "opus", "mp3"] sampleRate: Literal[16000, 24000, 48000] channels: Literal[1, 2] class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. - + Attributes: text: The text content from the model. role: The role of the message sender. """ - + text: str role: Role class InterruptionDetectedEvent(TypedDict): """Interruption detection event. - + Signals when user interruption is detected during model generation. - + Attributes: reason: Interruption reason from predefined set. """ - - reason: Literal['user_input', 'vad_detected', 'manual'] + + reason: Literal["user_input", "vad_detected", "manual"] class BidirectionalConnectionStartEvent(TypedDict, total=False): """connection start event for bidirectional streaming. - + Attributes: connectionId: Unique connection identifier. metadata: Provider-specific connection metadata. """ - + connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalConnectionEndEvent(TypedDict): """connection end event for bidirectional streaming. - + Attributes: reason: Reason for connection end from predefined set. connectionId: Unique connection identifier. metadata: Provider-specific connection metadata. """ - - reason: Literal['user_request', 'timeout', 'error'] + + reason: Literal["user_request", "timeout", "error"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. - + Extends the existing StreamEvent type with bidirectional-specific events while maintaining full backward compatibility with existing Strands streaming. - + Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. @@ -133,11 +135,10 @@ class BidirectionalStreamEvent(StreamEvent, total=False): BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. """ - + audioOutput: AudioOutputEvent audioInput: AudioInputEvent textOutput: TextOutputEvent interruptionDetected: InterruptionDetectedEvent BidirectionalConnectionStart: BidirectionalConnectionStartEvent BidirectionalConnectionEnd: BidirectionalConnectionEndEvent - diff --git a/src/strands/experimental/bidirectional_streaming/utils/__init__.py b/src/strands/experimental/bidirectional_streaming/utils/__init__.py new file mode 100644 index 000000000..579478436 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/utils/__init__.py @@ -0,0 +1,5 @@ +"""Utility functions for bidirectional streaming.""" + +from .debug import log_event, log_flow, time_it_async + +__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py index 1e88b6ead..6a7fc3982 100644 --- a/src/strands/experimental/bidirectional_streaming/utils/debug.py +++ b/src/strands/experimental/bidirectional_streaming/utils/debug.py @@ -11,30 +11,34 @@ # Debug logging system matching successful tool use example DEBUG = False # Disable debug logging for clean output like tool use example + def debug_print(message): """Print debug message with timestamp and function name.""" if DEBUG: function_name = inspect.stack()[1].function - if function_name == 'time_it_async': + if function_name == "time_it_async": function_name = inspect.stack()[2].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] print(f"{timestamp} {function_name} {message}") + def log_event(event_type, **context): """Log important events with structured context.""" if DEBUG: function_name = inspect.stack()[1].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") + def log_flow(step, details=""): """Log important flow steps without excessive detail.""" if DEBUG: function_name = inspect.stack()[1].function - timestamp = '{:%Y-%m-%d %H:%M:%S.%f}'.format(datetime.datetime.now())[:-3] + timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] print(f"{timestamp} {function_name} FLOW: {step} {details}") + async def time_it_async(label, method_to_run): """Time asynchronous method execution.""" start_time = time.perf_counter() @@ -42,4 +46,3 @@ async def time_it_async(label, method_to_run): end_time = time.perf_counter() debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") return result - From f7e67aec65640b9e262e88d4f82d020308143250 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Wed, 1 Oct 2025 23:59:44 -0400 Subject: [PATCH 05/23] fix linting issues --- pyproject.toml | 1 - .../event_loop/bidirectional_event_loop.py | 5 +++-- .../experimental/bidirectional_streaming/models/novasonic.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f45794d12..dd01ebde3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ bidirectional-streaming = [ "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", - "python>=3.12" ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index c90d118ff..4fbae3992 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,6 +21,7 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse +from ..agent.agent import BidirectionalAgent from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -61,7 +62,7 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.interrupted = False -async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: +async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: """Initialize bidirectional session with concurrent background tasks. Creates a model-specific session and starts background tasks for processing @@ -325,7 +326,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: session.pending_tool_tasks[task_id] = task # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) - def cleanup_task(completed_task): + def cleanup_task(completed_task, task_id=task_id): try: # Remove from pending tasks if task_id in session.pending_tool_tasks: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 0efd2413c..22912354d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -147,7 +147,7 @@ def _build_initialization_events( async def _send_initialization_events(self, events: List[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i + 1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) await asyncio.sleep(EVENT_DELAY) async def _process_responses(self) -> None: @@ -384,7 +384,7 @@ async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> No ] for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i + 1}", lambda: self._send_nova_event(event)) + await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) async def send_tool_error(self, tool_use_id: str, error: str) -> None: """Send tool error using Nova Sonic format.""" From c654621d9c345316c90e6895a430e2f1918a9b8c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 00:07:34 -0400 Subject: [PATCH 06/23] Remove typing module and rely on python's built-in types --- .../bidirectional_streaming/agent/agent.py | 10 ++--- .../event_loop/bidirectional_event_loop.py | 13 +++---- .../models/bidirectional_model.py | 12 +++--- .../models/novasonic.py | 38 +++++++++---------- 4 files changed, 36 insertions(+), 37 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index d7a5f17a3..997a0d1df 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -14,7 +14,7 @@ import asyncio import logging -from typing import AsyncIterable, List, Optional, Union +from typing import AsyncIterable from ....tools.executors import ConcurrentToolExecutor from ....tools.registry import ToolRegistry @@ -37,9 +37,9 @@ class BidirectionalAgent: def __init__( self, model: BidirectionalModel, - tools: Optional[List] = None, - system_prompt: Optional[str] = None, - messages: Optional[Messages] = None, + tools: list | None = None, + system_prompt: str | None = None, + messages: Messages | None = None, ): """Initialize bidirectional agent with required model and optional configuration. @@ -83,7 +83,7 @@ async def start(self) -> None: self._session = await start_bidirectional_connection(self) log_event("conversation_ready") - async def send(self, input_data: Union[str, AudioInputEvent]) -> None: + async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 4fbae3992..65ee6b905 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -16,7 +16,6 @@ import logging import traceback import uuid -from typing import Any, Dict from ....tools._validator import validate_and_prepare_tools from ....types.content import Message @@ -56,14 +55,14 @@ def __init__(self, model_session: BidirectionalModelSession, agent): self.audio_output_queue = asyncio.Queue() # Task management for cleanup - self.pending_tool_tasks: Dict[str, asyncio.Task] = {} + self.pending_tool_tasks: dict[str, asyncio.Task] = {} # Interruption handling (model-agnostic) self.interrupted = False async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: - """Initialize bidirectional session with concurrent background tasks. + """Initialize bidirectional session with conycurrent background tasks. Creates a model-specific session and starts background tasks for processing model events, executing tools, and managing the session lifecycle. @@ -365,7 +364,7 @@ def cleanup_task(completed_task, task_id=task_id): log_flow("tool_execution", "processor stopped") -def _convert_to_strands_event(provider_event: Dict) -> Dict: +def _convert_to_strands_event(provider_event: dict) -> dict: """Pass-through for events already normalized by provider sessions. Providers convert their raw events to standard format before reaching here. @@ -385,7 +384,7 @@ def _convert_to_strands_event(provider_event: Dict) -> Dict: return provider_event -async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: Dict) -> None: +async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. Executes tools using the existing Strands tool system, handles interruption @@ -501,11 +500,11 @@ def _extract_callable_function(tool_func): raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") -def _create_success_result(tool_use_id: str, result) -> Dict[str, Any]: +def _create_success_result(tool_use_id: str, result) -> dict[str, any]: """Create a successful tool result.""" return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} -def _create_error_result(tool_use_id: str, error: str) -> Dict[str, Any]: +def _create_error_result(tool_use_id: str, error: str) -> dict[str, any]: """Create an error tool result.""" return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index cc803458b..1432b112a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -13,7 +13,7 @@ import abc import logging -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import AsyncIterable from ....types.content import Messages from ....types.tools import ToolSpec @@ -31,7 +31,7 @@ class BidirectionalModelSession(abc.ABC): """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive events from the model in standardized format. Converts provider-specific events to a common format that can be @@ -67,7 +67,7 @@ async def send_interrupt(self) -> None: raise NotImplementedError @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. @@ -95,9 +95,9 @@ class BidirectionalModel(abc.ABC): @abc.abstractmethod async def create_bidirectional_connection( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, **kwargs, ) -> BidirectionalModelSession: """Create a bidirectional connection with the model. diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 22912354d..969cac159 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -19,7 +19,7 @@ import time import traceback import uuid -from typing import Any, AsyncIterable, Dict, List, Optional +from typing import AsyncIterable from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme @@ -80,7 +80,7 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream, config: Dict[str, Any]): + def __init__(self, stream, config: dict[str, any]): """Initialize Nova Sonic connection. Args: @@ -111,9 +111,9 @@ def __init__(self, stream, config: Dict[str, Any]): async def initialize( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, ) -> None: """Initialize Nova Sonic connection with required protocol sequence.""" try: @@ -132,8 +132,8 @@ async def initialize( raise def _build_initialization_events( - self, system_prompt: str, tools: List[ToolSpec], messages: Optional[Messages] - ) -> List[str]: + self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None + ) -> list[str]: """Build the sequence of initialization events.""" events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] @@ -144,7 +144,7 @@ def _build_initialization_events( return events - async def _send_initialization_events(self, events: List[str]) -> None: + async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) @@ -192,7 +192,7 @@ async def _handle_response_data(self, response_data: str) -> None: except json.JSONDecodeError as e: log_event("nova_json_error", error=str(e)) - def _log_event_type(self, nova_event: Dict[str, Any]) -> None: + def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: log_event("nova_usage", usage=nova_event["usageEvent"]) @@ -206,7 +206,7 @@ def _log_event_type(self, nova_event: Dict[str, Any]) -> None: audio_bytes = base64.b64decode(audio_content) log_event("nova_audio_output", bytes=len(audio_bytes)) - async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" if not self.stream: logger.error("Stream is None") @@ -370,7 +370,7 @@ async def send_interrupt(self) -> None: } await self._send_nova_event(interrupt_event) - async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: """Send tool result using Nova Sonic toolResult format.""" if not self._active: return @@ -433,7 +433,7 @@ async def close(self) -> None: finally: log_event("nova_connection_closed") - def _convert_nova_event(self, nova_event: Dict[str, Any]) -> Optional[Dict[str, Any]]: + def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: """Convert Nova Sonic events to provider-agnostic format.""" # Handle audio output if "audioOutput" in nova_event: @@ -512,7 +512,7 @@ def _get_connection_start_event(self) -> str: """Generate Nova Sonic connection start event.""" return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) - def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: """Generate Nova Sonic prompt start event with tool configuration.""" prompt_start_event = { "event": { @@ -531,7 +531,7 @@ def _get_prompt_start_event(self, tools: List[ToolSpec]) -> str: return json.dumps(prompt_start_event) - def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: """Build tool configuration from tool specs.""" tool_config = [] for tool in tools: @@ -546,7 +546,7 @@ def _build_tool_configuration(self, tools: List[ToolSpec]) -> List[Dict]: ) return tool_config - def _get_system_prompt_events(self, system_prompt: str) -> List[str]: + def _get_system_prompt_events(self, system_prompt: str) -> list[str]: """Generate system prompt events.""" content_name = str(uuid.uuid4()) return [ @@ -599,7 +599,7 @@ def _get_text_input_event(self, content_name: str, text: str) -> str: {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} ) - def _get_tool_result_event(self, content_name: str, result: Dict[str, Any]) -> str: + def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: """Generate tool result event.""" return json.dumps( { @@ -664,9 +664,9 @@ def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-e async def create_bidirectional_connection( self, - system_prompt: Optional[str] = None, - tools: Optional[List[ToolSpec]] = None, - messages: Optional[Messages] = None, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" From 1f1abacd839cd6ed26ebd9a84bfa2e8aeb50be01 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 00:12:15 -0400 Subject: [PATCH 07/23] add typing to methods --- .../event_loop/bidirectional_event_loop.py | 8 ++++---- .../bidirectional_streaming/models/novasonic.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 65ee6b905..ea00468bb 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -38,7 +38,7 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent): + def __init__(self, model_session: BidirectionalModelSession, agent: BidirectionalAgent) -> None: """Initialize session with model session and agent reference. Args: @@ -325,7 +325,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: session.pending_tool_tasks[task_id] = task # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) - def cleanup_task(completed_task, task_id=task_id): + def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: try: # Remove from pending tasks if task_id in session.pending_tool_tasks: @@ -488,7 +488,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: log_event("tool_error_send_failed", error=str(send_error)) -def _extract_callable_function(tool_func): +def _extract_callable_function(tool_func: any) -> any: """Extract the callable function from different tool object types.""" if hasattr(tool_func, "_tool_func"): return tool_func._tool_func @@ -500,7 +500,7 @@ def _extract_callable_function(tool_func): raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") -def _create_success_result(tool_use_id: str, result) -> dict[str, any]: +def _create_success_result(tool_use_id: str, result: any) -> dict[str, any]: """Create a successful tool result.""" return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 969cac159..89472350b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -80,7 +80,7 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream, config: dict[str, any]): + def __init__(self, stream: any, config: dict[str, any]) -> None: """Initialize Nova Sonic connection. Args: @@ -312,7 +312,7 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: # Start silence detection task self.silence_task = asyncio.create_task(self._check_silence()) - async def _check_silence(self): + async def _check_silence(self) -> None: """Check for silence and automatically end audio connection.""" try: await asyncio.sleep(self.silence_threshold) @@ -647,7 +647,7 @@ class NovaSonicBidirectionalModel(BidirectionalModel): streaming interface, handling AWS authentication and connection management. """ - def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config): + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: """Initialize Nova Sonic bidirectional model. Args: From eb543b52434dbe6af1f2f309f77446a97ed08871 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:00:04 -0400 Subject: [PATCH 08/23] Improve comments and remove unused method _convert_to_strands_event --- .../bidirectional_streaming/agent/agent.py | 1 - .../event_loop/bidirectional_event_loop.py | 45 +++++++------------ 2 files changed, 15 insertions(+), 31 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 997a0d1df..e27885c7e 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -98,7 +98,6 @@ async def send(self, input_data: str | AudioInputEvent) -> None: self._validate_active_session() if isinstance(input_data, str): - # Handle text input log_event("text_sent", length=len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index ea00468bb..fddd1245a 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -189,7 +189,7 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: log_event("interruption_detected") session.interrupted = True - # 🔥 CANCEL ALL PENDING TOOL TASKS (Nova Sonic pattern) + # Cancel all pending tool execution tasks cancelled_tools = 0 for task_id, task in list(session.pending_tool_tasks.items()): if not task.done(): @@ -200,7 +200,7 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: if cancelled_tools > 0: log_event("tool_tasks_cancelled", count=cancelled_tools) - # 🔥 AGGRESSIVELY CLEAR AUDIO OUTPUT QUEUE (Nova Sonic pattern) + # Clear all queued audio output events cleared_count = 0 while True: try: @@ -258,8 +258,11 @@ async def _process_model_events(session: BidirectionalConnection) -> None: if not session.active: break - # Convert provider events to Strands format - strands_event = _convert_to_strands_event(provider_event) + # Basic validation - skip invalid events + if not isinstance(provider_event, dict): + continue + + strands_event = provider_event # Handle interruption detection (multiple patterns) if strands_event.get("interruptionDetected"): @@ -269,7 +272,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: await session.agent._output_queue.put(strands_event) continue - # Check for text-based interruption (Nova Sonic pattern) + # Check for text-based interruption if strands_event.get("textOutput"): text_content = strands_event["textOutput"].get("content", "") if '{ "interrupted" : true }' in text_content: @@ -324,7 +327,6 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task - # 🔥 ADD CLEANUP CALLBACK (Nova Sonic pattern) def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: try: # Remove from pending tasks @@ -346,7 +348,7 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: except asyncio.TimeoutError: if not session.active: break - # 🔥 PERIODIC CLEANUP OF COMPLETED TASKS + # Remove completed tasks from tracking completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] for task_id in completed_tasks: if task_id in session.pending_tool_tasks: @@ -364,24 +366,7 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: log_flow("tool_execution", "processor stopped") -def _convert_to_strands_event(provider_event: dict) -> dict: - """Pass-through for events already normalized by provider sessions. - - Providers convert their raw events to standard format before reaching here. - This just validates and passes through the normalized events. - - Args: - provider_event: Already normalized event from provider session. - - Returns: - Dict: The same event, validated and passed through. - """ - # Basic validation - ensure we have a dict - if not isinstance(provider_event, dict): - return {} - # Pass through - conversion already done by provider session - return provider_event async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: @@ -398,7 +383,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_id = tool_use.get("toolUseId") try: - # 🔥 CHECK FOR INTERRUPTION BEFORE STARTING (Nova Sonic pattern) + # Skip execution if session is interrupted or inactive if session.interrupted or not session.active: log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) return @@ -422,7 +407,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: - # 🔥 CHECK FOR INTERRUPTION DURING EXECUTION + # Return early if session was interrupted during execution if session.interrupted or not session.active: log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) return @@ -433,12 +418,12 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: try: actual_func = _extract_callable_function(tool_func) - # 🔥 WRAP TOOL EXECUTION IN CANCELLATION CHECK + # Execute tool function with provided input # For async tools, we could wrap with asyncio.wait_for with cancellation # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - # 🔥 CHECK FOR INTERRUPTION AFTER TOOL EXECUTION + # Discard result if session was interrupted during execution if session.interrupted or not session.active: log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) return @@ -451,7 +436,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: log_event("tool_execution_cancelled", name=tool_name, id=tool_id) return except Exception as e: - # 🔥 CHECK FOR INTERRUPTION EVEN ON ERROR + # Discard error result if session was interrupted if session.interrupted or not session.active: log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) return @@ -462,7 +447,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: else: log_event("tool_not_found", name=tool_name) - # 🔥 FINAL INTERRUPTION CHECK BEFORE SENDING RESULTS + # Skip sending results if session was interrupted if session.interrupted or not session.active: log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) return From 5921f8bdb24740adb2b6ad2af609218674b4b4b5 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:23:21 -0400 Subject: [PATCH 09/23] Updated: fixed module imports baesd on the new smithy python release on 09-29, added a lock for interruption handling --- .../event_loop/bidirectional_event_loop.py | 118 ++++++++++-------- .../models/novasonic.py | 6 +- .../tests/test_bidirectional_streaming.py | 4 +- 3 files changed, 68 insertions(+), 60 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index fddd1245a..358fdcea3 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -20,7 +20,7 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..agent.agent import BidirectionalAgent + from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -38,7 +38,7 @@ class BidirectionalConnection: handling while providing a simple interface for agent interactions. """ - def __init__(self, model_session: BidirectionalModelSession, agent: BidirectionalAgent) -> None: + def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None: """Initialize session with model session and agent reference. Args: @@ -59,9 +59,10 @@ def __init__(self, model_session: BidirectionalModelSession, agent: Bidirectiona # Interruption handling (model-agnostic) self.interrupted = False + self.interruption_lock = asyncio.Lock() -async def start_bidirectional_connection(agent: BidirectionalAgent) -> BidirectionalConnection: +async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: """Initialize bidirectional session with conycurrent background tasks. Creates a model-specific session and starts background tasks for processing @@ -181,66 +182,73 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: """Handle interruption detection with task cancellation and audio buffer clearing. Cancels pending tool tasks and clears audio output queues to ensure responsive - interruption handling during conversations. + interruption handling during conversations. Protected by async lock to prevent + concurrent execution and race conditions. Args: session: BidirectionalConnection to handle interruption for. """ - log_event("interruption_detected") - session.interrupted = True + async with session.interruption_lock: + # If already interrupted, skip duplicate processing + if session.interrupted: + log_event("interruption_already_in_progress") + return - # Cancel all pending tool execution tasks - cancelled_tools = 0 - for task_id, task in list(session.pending_tool_tasks.items()): - if not task.done(): - task.cancel() - cancelled_tools += 1 - log_event("tool_task_cancelled", task_id=task_id) + log_event("interruption_detected") + session.interrupted = True - if cancelled_tools > 0: - log_event("tool_tasks_cancelled", count=cancelled_tools) + # Cancel all pending tool execution tasks + cancelled_tools = 0 + for task_id, task in list(session.pending_tool_tasks.items()): + if not task.done(): + task.cancel() + cancelled_tools += 1 + log_event("tool_task_cancelled", task_id=task_id) - # Clear all queued audio output events - cleared_count = 0 - while True: - try: - session.audio_output_queue.get_nowait() - cleared_count += 1 - except asyncio.QueueEmpty: - break + if cancelled_tools > 0: + log_event("tool_tasks_cancelled", count=cancelled_tools) - # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, "_output_queue"): - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) - - if cleared_count > 0: - log_event("session_audio_queue_cleared", count=cleared_count) - - # Brief sleep to allow audio system to settle (matches Nova Sonic timing) - await asyncio.sleep(0.05) - - # Reset interruption flag after clearing (automatic recovery) - session.interrupted = False - log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + # Clear all queued audio output events + cleared_count = 0 + while True: + try: + session.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break + + # Also clear the agent's audio output queue if it exists + if hasattr(session.agent, "_output_queue"): + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) + + if cleared_count > 0: + log_event("session_audio_queue_cleared", count=cleared_count) + + # Brief sleep to allow audio system to settle (matches Nova Sonic timing) + await asyncio.sleep(0.05) + + # Reset interruption flag after clearing (automatic recovery) + session.interrupted = False + log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) async def _process_model_events(session: BidirectionalConnection) -> None: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 89472350b..e79229623 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -24,7 +24,7 @@ from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk -from smithy_aws_core.credentials_resolvers.environment import EnvironmentCredentialsResolver +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages from ....types.tools import ToolSpec, ToolUse @@ -703,8 +703,8 @@ async def _initialize_client(self) -> None: endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", region=self.region, aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), - http_auth_scheme_resolver=HTTPAuthSchemeResolver(), - http_auth_schemes={"aws.auth#sigv4": SigV4AuthScheme()}, + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, ) self._client = BedrockRuntimeClient(config=config) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index 6ef96f919..b31607966 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -15,8 +15,8 @@ import pyaudio from strands_tools import calculator -from ..agent.agent import BidirectionalAgent -from ..models.novasonic import NovaSonicBidirectionalModel +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel async def play(context): From 8cb4d98ba035d021cdff1953cf9705cca114e270 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 12:37:33 -0400 Subject: [PATCH 10/23] Removed unnecessary _output_queue check as the queue will always be initialized, and removed asyncio.sleep() as they were mainly for defensive purposes and following the pattern of nova sonic samples. --- .../event_loop/bidirectional_event_loop.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 358fdcea3..b4395f38e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -20,7 +20,6 @@ from ....tools._validator import validate_and_prepare_tools from ....types.content import Message from ....types.tools import ToolResult, ToolUse - from ..models.bidirectional_model import BidirectionalModelSession from ..utils.debug import log_event, log_flow @@ -95,10 +94,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start main coordination cycle session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - # Give background tasks a moment to start - await asyncio.sleep(0.1) log_event("session_ready", tasks=len(session.background_tasks)) - return session @@ -217,35 +213,31 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: except asyncio.QueueEmpty: break - # Also clear the agent's audio output queue if it exists - if hasattr(session.agent, "_output_queue"): - audio_cleared = 0 - # Create a temporary list to hold non-audio events - temp_events = [] - try: - while True: - event = session.agent._output_queue.get_nowait() - if event.get("audioOutput"): - audio_cleared += 1 - else: - # Keep non-audio events - temp_events.append(event) - except asyncio.QueueEmpty: - pass - - # Put back non-audio events - for event in temp_events: - session.agent._output_queue.put_nowait(event) - - if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) + # Also clear the agent's audio output queue + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + log_event("agent_audio_queue_cleared", count=audio_cleared) if cleared_count > 0: log_event("session_audio_queue_cleared", count=cleared_count) - # Brief sleep to allow audio system to settle (matches Nova Sonic timing) - await asyncio.sleep(0.05) - # Reset interruption flag after clearing (automatic recovery) session.interrupted = False log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) From 7a6e53efdf669352bd18f19531178d46589c214d Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 13:03:42 -0400 Subject: [PATCH 11/23] Remove redundant interruption checks --- .../event_loop/bidirectional_event_loop.py | 67 +++---------------- 1 file changed, 11 insertions(+), 56 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index b4395f38e..cc4f416b7 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -264,7 +264,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: strands_event = provider_event - # Handle interruption detection (multiple patterns) + # Handle interruption detection (provider converts raw patterns to interruptionDetected) if strands_event.get("interruptionDetected"): log_event("interruption_forwarded") await _handle_interruption(session) @@ -272,16 +272,6 @@ async def _process_model_events(session: BidirectionalConnection) -> None: await session.agent._output_queue.put(strands_event) continue - # Check for text-based interruption - if strands_event.get("textOutput"): - text_content = strands_event["textOutput"].get("content", "") - if '{ "interrupted" : true }' in text_content: - log_event("text_interruption_detected") - await _handle_interruption(session) - # Still forward the text event - await session.agent._output_queue.put(strands_event) - continue - # Queue tool requests for concurrent execution if strands_event.get("toolUse"): log_event("tool_queued", name=strands_event["toolUse"].get("name")) @@ -308,8 +298,8 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Includes proper task cleanup and cancellation - handling for interruptions. + processing or user interaction. Uses proper asyncio cancellation for + interruption handling rather than manual state checks. Args: session: BidirectionalConnection containing tool queue. @@ -320,9 +310,6 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) - if not session.active: - break - task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) session.pending_tool_tasks[task_id] = task @@ -372,8 +359,9 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. - Executes tools using the existing Strands tool system, handles interruption - during execution, and sends results back to the model provider. + Executes tools using the existing Strands tool system with proper asyncio + cancellation handling. Tool execution is stopped via task cancellation, + not manual state checks. Args: session: BidirectionalConnection for context. @@ -383,11 +371,6 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_id = tool_use.get("toolUseId") try: - # Skip execution if session is interrupted or inactive - if session.interrupted or not session.active: - log_event("tool_execution_cancelled_before_start", name=tool_name, id=tool_id) - return - # Create message structure for existing tool system tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} @@ -407,11 +390,6 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: # Execute tools directly (simpler approach for bidirectional) for tool_use in valid_tool_uses: - # Return early if session was interrupted during execution - if session.interrupted or not session.active: - log_event("tool_execution_cancelled_during", name=tool_name, id=tool_id) - return - tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) if tool_func: @@ -419,39 +397,18 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: actual_func = _extract_callable_function(tool_func) # Execute tool function with provided input - # For async tools, we could wrap with asyncio.wait_for with cancellation - # For sync tools, we execute directly but check interruption after result = actual_func(**tool_use.get("input", {})) - # Discard result if session was interrupted during execution - if session.interrupted or not session.active: - log_event("tool_result_discarded_interruption", name=tool_name, id=tool_id) - return - tool_result = _create_success_result(tool_use["toolUseId"], result) tool_results.append(tool_result) - except asyncio.CancelledError: - # Tool was cancelled due to interruption - log_event("tool_execution_cancelled", name=tool_name, id=tool_id) - return except Exception as e: - # Discard error result if session was interrupted - if session.interrupted or not session.active: - log_event("tool_error_discarded_interruption", name=tool_name, id=tool_id) - return - log_event("tool_execution_failed", name=tool_name, error=str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: log_event("tool_not_found", name=tool_name) - # Skip sending results if session was interrupted - if session.interrupted or not session.active: - log_event("tool_results_discarded_interruption", name=tool_name, count=len(tool_results)) - return - # Send results through provider-specific session for result in tool_results: await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) @@ -464,13 +421,11 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: raise # Re-raise to properly handle cancellation except Exception as e: log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - - # Only send error if not interrupted - if not session.interrupted and session.active: - try: - await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) - except Exception as send_error: - log_event("tool_error_send_failed", error=str(send_error)) + + try: + await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) + except Exception as send_error: + log_event("tool_error_send_failed", error=str(send_error)) def _extract_callable_function(tool_func: any) -> any: From a58626107b21dad40a52bf27320f35e1af9a5df8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 13:25:51 -0400 Subject: [PATCH 12/23] Unified tool result and tool error methods, Added implementation to add user messages to the agent messages --- .../bidirectional_streaming/agent/agent.py | 8 ++++++-- .../event_loop/bidirectional_event_loop.py | 19 ++++++++++++------- .../models/bidirectional_model.py | 6 +----- .../models/novasonic.py | 6 ------ 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index e27885c7e..46bc38ef2 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -87,7 +87,8 @@ async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). Unified method for sending both text and audio input to the model during - an active conversation session. + an active conversation session. User input is automatically added to + conversation history for complete message tracking. Args: input_data: Either a string for text input or AudioInputEvent for audio input. @@ -98,10 +99,13 @@ async def send(self, input_data: str | AudioInputEvent) -> None: self._validate_active_session() if isinstance(input_data, str): + # Add user text message to history + self.messages.append({"role": "user", "content": input_data}) + log_event("text_sent", length=len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: - # Handle audio input (AudioInputEvent) + # Handle audio input await self._session.model_session.send_audio_content(input_data) else: raise ValueError( diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index cc4f416b7..684c0037e 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -261,7 +261,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Basic validation - skip invalid events if not isinstance(provider_event, dict): continue - + strands_event = provider_event # Handle interruption detection (provider converts raw patterns to interruptionDetected) @@ -287,6 +287,14 @@ async def _process_model_events(session: BidirectionalConnection) -> None: log_event("message_added_to_history") session.agent.messages.append(strands_event["messageStop"]["message"]) + # Handle user audio transcripts - add to message history + if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": + user_transcript = strands_event["textOutput"]["text"] + if user_transcript.strip(): # Only add non-empty transcripts + user_message = {"role": "user", "content": user_transcript} + session.agent.messages.append(user_message) + log_event("user_transcript_added_to_history") + except Exception as e: log_event("model_events_error", error=str(e)) traceback.print_exc() @@ -298,7 +306,7 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for + processing or user interaction. Uses proper asyncio cancellation for interruption handling rather than manual state checks. Args: @@ -353,9 +361,6 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: log_flow("tool_execution", "processor stopped") - - - async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: """Execute tool using Strands infrastructure with interruption support. @@ -421,9 +426,9 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: raise # Re-raise to properly handle cancellation except Exception as e: log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - + try: - await session.model_session.send_tool_error(tool_use.get("toolUseId"), str(e)) + await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) except Exception as send_error: log_event("tool_error_send_failed", error=str(send_error)) diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 1432b112a..4cd9cc6b8 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -71,14 +71,10 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. + Handles both successful results and error cases. """ raise NotImplementedError - @abc.abstractmethod - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool execution error to model in provider-specific format.""" - raise NotImplementedError - @abc.abstractmethod async def close(self) -> None: """Close the connection and cleanup resources.""" diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index e79229623..dfd911172 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -386,12 +386,6 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No for i, event in enumerate(events): await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) - async def send_tool_error(self, tool_use_id: str, error: str) -> None: - """Send tool error using Nova Sonic format.""" - log_event("nova_tool_error_send", id=tool_use_id, error=error) - error_result = {"error": error} - await self.send_tool_result(tool_use_id, error_result) - async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: From 16d9b461d187b45ee6d3305268ef23293accd3b0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:00:25 -0400 Subject: [PATCH 13/23] Modified logging to use python logger --- .../bidirectional_streaming/agent/agent.py | 8 +- .../event_loop/bidirectional_event_loop.py | 89 ++++++++++--------- .../models/novasonic.py | 67 +++++++------- 3 files changed, 83 insertions(+), 81 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 46bc38ef2..68d371a51 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -22,7 +22,7 @@ from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent -from ..utils.debug import log_event, log_flow + logger = logging.getLogger(__name__) @@ -79,9 +79,9 @@ async def start(self) -> None: if self._session and self._session.active: raise ValueError("Conversation already active. Call end() first.") - log_flow("conversation_start", "initializing session") + logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - log_event("conversation_ready") + logger.debug("Conversation ready") async def send(self, input_data: str | AudioInputEvent) -> None: """Send input to the model (text or audio). @@ -102,7 +102,7 @@ async def send(self, input_data: str | AudioInputEvent) -> None: # Add user text message to history self.messages.append({"role": "user", "content": input_data}) - log_event("text_sent", length=len(input_data)) + logger.debug("Text sent: %d characters", len(input_data)) await self._session.model_session.send_text_content(input_data) elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 684c0037e..16be08aaf 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession -from ..utils.debug import log_event, log_flow + logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec Returns: BidirectionalConnection: Active session with background tasks running. """ - log_flow("session_start", "initializing model session") + logger.debug("Starting bidirectional session - initializing model session") # Create provider-specific session model_session = await agent.model.create_bidirectional_connection( @@ -85,7 +85,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start concurrent background processors IMMEDIATELY after session creation # This is critical - Nova Sonic needs response processing during initialization - log_flow("background_tasks", "starting processors") + logger.debug("Starting background processors for concurrent processing") session.background_tasks = [ asyncio.create_task(_process_model_events(session)), # Handle model responses asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently @@ -94,7 +94,7 @@ async def start_bidirectional_connection(agent: "BidirectionalAgent") -> Bidirec # Start main coordination cycle session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) - log_event("session_ready", tasks=len(session.background_tasks)) + logger.debug("Session ready with %d background tasks", len(session.background_tasks)) return session @@ -107,7 +107,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non if not session.active: return - log_flow("session_cleanup", "starting") + logger.debug("Session cleanup starting") session.active = False # Cancel pending tool tasks @@ -134,7 +134,7 @@ async def stop_bidirectional_connection(session: BidirectionalConnection) -> Non # Close model session await session.model_session.close() - log_event("session_closed") + logger.debug("Session closed") async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: @@ -150,7 +150,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No try: # Check if background processors are still running if all(task.done() for task in session.background_tasks): - log_event("session_end", reason="all_processors_completed") + logger.debug("Session end - all processors completed") session.active = False break @@ -159,7 +159,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No if task.done() and not task.cancelled(): exception = task.exception() if exception: - log_event("session_error", processor=i, error=str(exception)) + logger.error("Session error in processor %d: %s", i, str(exception)) session.active = False raise exception @@ -169,7 +169,7 @@ async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> No except asyncio.CancelledError: break except Exception as e: - log_event("event_loop_error", error=str(e)) + logger.error("Event loop error: %s", str(e)) session.active = False raise @@ -187,10 +187,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: async with session.interruption_lock: # If already interrupted, skip duplicate processing if session.interrupted: - log_event("interruption_already_in_progress") + logger.debug("Interruption already in progress") return - log_event("interruption_detected") + logger.debug("Interruption detected") session.interrupted = True # Cancel all pending tool execution tasks @@ -199,10 +199,10 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: if not task.done(): task.cancel() cancelled_tools += 1 - log_event("tool_task_cancelled", task_id=task_id) + logger.debug("Tool task cancelled: %s", task_id) if cancelled_tools > 0: - log_event("tool_tasks_cancelled", count=cancelled_tools) + logger.debug("Tool tasks cancelled: %d", cancelled_tools) # Clear all queued audio output events cleared_count = 0 @@ -233,14 +233,14 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: session.agent._output_queue.put_nowait(event) if audio_cleared > 0: - log_event("agent_audio_queue_cleared", count=audio_cleared) + logger.debug("Agent audio queue cleared: %d events", audio_cleared) if cleared_count > 0: - log_event("session_audio_queue_cleared", count=cleared_count) + logger.debug("Session audio queue cleared: %d events", cleared_count) # Reset interruption flag after clearing (automatic recovery) session.interrupted = False - log_event("interruption_handled", tools_cancelled=cancelled_tools, audio_cleared=cleared_count) + logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) async def _process_model_events(session: BidirectionalConnection) -> None: @@ -252,7 +252,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: Args: session: BidirectionalConnection containing model session. """ - log_flow("model_events", "processor started") + logger.debug("Model events processor started") try: async for provider_event in session.model_session.receive_events(): if not session.active: @@ -261,12 +261,12 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Basic validation - skip invalid events if not isinstance(provider_event, dict): continue - + strands_event = provider_event # Handle interruption detection (provider converts raw patterns to interruptionDetected) if strands_event.get("interruptionDetected"): - log_event("interruption_forwarded") + logger.debug("Interruption forwarded") await _handle_interruption(session) # Forward interruption event to agent for application-level handling await session.agent._output_queue.put(strands_event) @@ -274,7 +274,7 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution if strands_event.get("toolUse"): - log_event("tool_queued", name=strands_event["toolUse"].get("name")) + logger.debug("Tool queued: %s", strands_event["toolUse"].get("name")) await session.tool_queue.put(strands_event["toolUse"]) continue @@ -284,39 +284,39 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Update Agent conversation history using existing patterns if strands_event.get("messageStop"): - log_event("message_added_to_history") + logger.debug("Message added to history") session.agent.messages.append(strands_event["messageStop"]["message"]) - + # Handle user audio transcripts - add to message history if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": user_transcript = strands_event["textOutput"]["text"] if user_transcript.strip(): # Only add non-empty transcripts user_message = {"role": "user", "content": user_transcript} session.agent.messages.append(user_message) - log_event("user_transcript_added_to_history") + logger.debug("User transcript added to history") except Exception as e: - log_event("model_events_error", error=str(e)) + logger.error("Model events error: %s", str(e)) traceback.print_exc() finally: - log_flow("model_events", "processor stopped") + logger.debug("Model events processor stopped") async def _process_tool_execution(session: BidirectionalConnection) -> None: """Execute tools concurrently with interruption support. Background task that manages tool execution without blocking model event - processing or user interaction. Uses proper asyncio cancellation for + processing or user interaction. Uses proper asyncio cancellation for interruption handling rather than manual state checks. Args: session: BidirectionalConnection containing tool queue. """ - log_flow("tool_execution", "processor started") + logger.debug("Tool execution processor started") while session.active: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - log_event("tool_execution_started", name=tool_use.get("name"), id=tool_use.get("toolUseId")) + logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId")) task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) @@ -330,13 +330,13 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: # Log completion status if completed_task.cancelled(): - log_event("tool_task_cleanup_cancelled", task_id=task_id) + logger.debug("Tool task cleanup cancelled: %s", task_id) elif completed_task.exception(): - log_event("tool_task_cleanup_error", task_id=task_id, error=str(completed_task.exception())) + logger.error("Tool task cleanup error: %s - %s", task_id, str(completed_task.exception())) else: - log_event("tool_task_cleanup_success", task_id=task_id) + logger.debug("Tool task cleanup success: %s", task_id) except Exception as e: - log_event("tool_task_cleanup_failed", task_id=task_id, error=str(e)) + logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) task.add_done_callback(cleanup_task) @@ -350,15 +350,18 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: del session.pending_tool_tasks[task_id] if completed_tasks: - log_event("periodic_task_cleanup", count=len(completed_tasks)) + logger.debug("Periodic task cleanup: %d tasks", len(completed_tasks)) continue except Exception as e: - log_event("tool_execution_error", error=str(e)) + logger.error("Tool execution error: %s", str(e)) if not session.active: break - log_flow("tool_execution", "processor stopped") + logger.debug("Tool execution processor stopped") + + + async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: @@ -390,7 +393,7 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] if not valid_tool_uses: - log_event("tool_validation_failed", name=tool_name, id=tool_id) + logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id) return # Execute tools directly (simpler approach for bidirectional) @@ -408,29 +411,29 @@ async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: tool_results.append(tool_result) except Exception as e: - log_event("tool_execution_failed", name=tool_name, error=str(e)) + logger.error("Tool execution failed: %s - %s", tool_name, str(e)) tool_result = _create_error_result(tool_use["toolUseId"], str(e)) tool_results.append(tool_result) else: - log_event("tool_not_found", name=tool_name) + logger.warning("Tool not found: %s", tool_name) # Send results through provider-specific session for result in tool_results: await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) - log_event("tool_execution_completed", name=tool_name, results=len(tool_results)) + logger.debug("Tool execution completed: %s (%d results)", tool_name, len(tool_results)) except asyncio.CancelledError: # Task was cancelled due to interruption - this is expected behavior - log_event("tool_task_cancelled_gracefully", name=tool_name, id=tool_id) + logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id) raise # Re-raise to properly handle cancellation except Exception as e: - log_event("tool_execution_error", name=tool_use.get("name"), error=str(e)) - + logger.error("Tool execution error: %s - %s", tool_use.get("name"), str(e)) + try: await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) except Exception as send_error: - log_event("tool_error_send_failed", error=str(send_error)) + logger.error("Tool error send failed: %s", str(send_error)) def _extract_callable_function(tool_func: any) -> any: diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index dfd911172..7f7937ef1 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -36,7 +36,7 @@ InterruptionDetectedEvent, TextOutputEvent, ) -from ..utils.debug import log_event, log_flow, time_it_async + from .bidirectional_model import BidirectionalModel, BidirectionalModelSession logger = logging.getLogger(__name__) @@ -121,10 +121,10 @@ async def initialize( init_events = self._build_initialization_events(system_prompt, tools or [], messages) - log_flow("nova_init", f"sending {len(init_events)} events") + logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") await self._send_initialization_events(init_events) - log_event("nova_connection_initialized") + logger.info("Nova Sonic connection initialized successfully") self._response_task = asyncio.create_task(self._process_responses()) except Exception as e: @@ -147,12 +147,12 @@ def _build_initialization_events( async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" for i, event in enumerate(events): - await time_it_async(f"send_init_event_{i + 1}", lambda event=event: self._send_nova_event(event)) + await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) async def _process_responses(self) -> None: """Process Nova Sonic responses continuously.""" - log_flow("nova_responses", "processor started") + logger.debug("Nova Sonic response processor started") try: while self._active: @@ -167,14 +167,14 @@ async def _process_responses(self) -> None: await asyncio.sleep(0.1) continue except Exception as e: - log_event("nova_response_error", error=str(e)) + logger.warning(f"Nova Sonic response error: {e}") await asyncio.sleep(0.1) continue except Exception as e: - log_event("nova_fatal_error", error=str(e)) + logger.error(f"Nova Sonic fatal error: {e}") finally: - log_flow("nova_responses", "processor stopped") + logger.debug("Nova Sonic response processor stopped") async def _handle_response_data(self, response_data: str) -> None: """Handle decoded response data from Nova Sonic.""" @@ -190,21 +190,21 @@ async def _handle_response_data(self, response_data: str) -> None: await self._event_queue.put(nova_event) except json.JSONDecodeError as e: - log_event("nova_json_error", error=str(e)) + logger.warning(f"Nova Sonic JSON decode error: {e}") def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" if "usageEvent" in nova_event: - log_event("nova_usage", usage=nova_event["usageEvent"]) + logger.debug("Nova usage: %s", nova_event["usageEvent"]) elif "textOutput" in nova_event: - log_event("nova_text_output") + logger.debug("Nova text output") elif "toolUse" in nova_event: tool_use = nova_event["toolUse"] - log_event("nova_tool_use", name=tool_use["toolName"], id=tool_use["toolUseId"]) + logger.debug("Nova tool use: %s (id: %s)", tool_use["toolName"], tool_use["toolUseId"]) elif "audioOutput" in nova_event: audio_content = nova_event["audioOutput"]["content"] audio_bytes = base64.b64decode(audio_content) - log_event("nova_audio_output", bytes=len(audio_bytes)) + logger.debug("Nova audio output: %d bytes", len(audio_bytes)) async def receive_events(self) -> AsyncIterable[dict[str, any]]: """Receive Nova Sonic events and convert to provider-agnostic format.""" @@ -212,7 +212,7 @@ async def receive_events(self) -> AsyncIterable[dict[str, any]]: logger.error("Stream is None") return - log_flow("nova_events", "starting event stream") + logger.debug("Nova events - starting event stream") # Emit connection start event to Strands event system connection_start: BidirectionalConnectionStartEvent = { @@ -257,7 +257,7 @@ async def start_audio_connection(self) -> None: if self.audio_connection_active: return - log_event("nova_audio_connection_start") + logger.debug("Nova audio connection start") audio_content_start = json.dumps( { @@ -319,7 +319,7 @@ async def _check_silence(self) -> None: if self.audio_connection_active and self.last_audio_time: elapsed = time.time() - self.last_audio_time if elapsed >= self.silence_threshold: - log_event("nova_silence_detected", elapsed=elapsed) + logger.debug("Nova silence detected: %.2f seconds", elapsed) await self.end_audio_input() except asyncio.CancelledError: pass @@ -329,7 +329,7 @@ async def end_audio_input(self) -> None: if not self.audio_connection_active: return - log_event("nova_audio_connection_end") + logger.debug("Nova audio connection end") audio_content_end = json.dumps( {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}} @@ -375,7 +375,7 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No if not self._active: return - log_event("nova_tool_result_send", id=tool_use_id) + logger.debug("Nova tool result send: %s", tool_use_id) content_name = str(uuid.uuid4()) events = [ self._get_tool_content_start_event(content_name, tool_use_id), @@ -384,14 +384,16 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No ] for i, event in enumerate(events): - await time_it_async(f"send_tool_event_{i + 1}", lambda event=event: self._send_nova_event(event)) + await self._send_nova_event(event) + + async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: return - log_flow("nova_cleanup", "starting connection close") + logger.debug("Nova cleanup - starting connection close") self._active = False # Cancel response processing task if running @@ -423,9 +425,9 @@ async def close(self) -> None: logger.warning("Error closing Nova Sonic stream: %s", e) except Exception as e: - log_event("nova_cleanup_error", error=str(e)) + logger.error("Nova cleanup error: %s", str(e)) finally: - log_event("nova_connection_closed") + logger.debug("Nova connection closed") def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: """Convert Nova Sonic events to provider-agnostic format.""" @@ -452,7 +454,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Check for Nova Sonic interruption pattern (matches working sample) if '{ "interrupted" : true }' in text_content: - log_event("nova_interruption_in_text") + logger.debug("Nova interruption detected in text") interruption: InterruptionDetectedEvent = {"reason": "user_input"} return {"interruptionDetected": interruption} @@ -480,7 +482,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Handle interruption elif nova_event.get("stopReason") == "INTERRUPTED": - log_event("nova_interruption_stop_reason") + logger.debug("Nova interruption stop reason") interruption: InterruptionDetectedEvent = {"reason": "user_input"} @@ -664,29 +666,26 @@ async def create_bidirectional_connection( **kwargs, ) -> BidirectionalModelSession: """Create Nova Sonic bidirectional connection.""" - log_flow("nova_connection_create", "starting") + logger.debug("Nova connection create - starting") # Initialize client if needed if not self._client: - await time_it_async("initialize_client", lambda: self._initialize_client()) + await self._initialize_client() # Start Nova Sonic bidirectional stream try: - stream = await time_it_async( - "invoke_model_with_bidirectional_stream", - lambda: self._client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ), + stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) ) # Create and initialize connection connection = NovaSonicSession(stream, self.config) - await time_it_async("initialize_connection", lambda: connection.initialize(system_prompt, tools, messages)) + await connection.initialize(system_prompt, tools, messages) - log_event("nova_connection_created") + logger.debug("Nova connection created") return connection except Exception as e: - log_event("nova_connection_create_error", error=str(e)) + logger.error("Nova connection create error: %s", str(e)) logger.error("Failed to create Nova Sonic connection: %s", e) raise From 04265baa9267865fe9686dbe89440c552e77f2da Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:02:24 -0400 Subject: [PATCH 14/23] Removed logging utility --- .../bidirectional_streaming/utils/__init__.py | 5 -- .../bidirectional_streaming/utils/debug.py | 48 ------------------- 2 files changed, 53 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/utils/__init__.py delete mode 100644 src/strands/experimental/bidirectional_streaming/utils/debug.py diff --git a/src/strands/experimental/bidirectional_streaming/utils/__init__.py b/src/strands/experimental/bidirectional_streaming/utils/__init__.py deleted file mode 100644 index 579478436..000000000 --- a/src/strands/experimental/bidirectional_streaming/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Utility functions for bidirectional streaming.""" - -from .debug import log_event, log_flow, time_it_async - -__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/utils/debug.py b/src/strands/experimental/bidirectional_streaming/utils/debug.py deleted file mode 100644 index 6a7fc3982..000000000 --- a/src/strands/experimental/bidirectional_streaming/utils/debug.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Debug utilities for Strands bidirectional streaming. - -Provides consistent debug logging across all bidirectional streaming components -with configurable output control matching the Nova Sonic tool use example. -""" - -import datetime -import inspect -import time - -# Debug logging system matching successful tool use example -DEBUG = False # Disable debug logging for clean output like tool use example - - -def debug_print(message): - """Print debug message with timestamp and function name.""" - if DEBUG: - function_name = inspect.stack()[1].function - if function_name == "time_it_async": - function_name = inspect.stack()[2].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - print(f"{timestamp} {function_name} {message}") - - -def log_event(event_type, **context): - """Log important events with structured context.""" - if DEBUG: - function_name = inspect.stack()[1].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - context_str = " ".join([f"{k}={v}" for k, v in context.items()]) if context else "" - print(f"{timestamp} {function_name} EVENT: {event_type} {context_str}") - - -def log_flow(step, details=""): - """Log important flow steps without excessive detail.""" - if DEBUG: - function_name = inspect.stack()[1].function - timestamp = "{:%Y-%m-%d %H:%M:%S.%f}".format(datetime.datetime.now())[:-3] - print(f"{timestamp} {function_name} FLOW: {step} {details}") - - -async def time_it_async(label, method_to_run): - """Time asynchronous method execution.""" - start_time = time.perf_counter() - result = await method_to_run() - end_time = time.perf_counter() - debug_print(f"Execution time for {label}: {end_time - start_time:.4f} seconds") - return result From 8a7396cf0715409b7fb35deb2c51b1164541a307 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 2 Oct 2025 14:36:36 -0400 Subject: [PATCH 15/23] Updated types --- .../experimental/bidirectional_streaming/__init__.py | 3 --- .../bidirectional_streaming/models/bidirectional_model.py | 6 +++--- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index f6a3b41bf..52822711a 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,5 +1,2 @@ """Bidirectional streaming package for real-time audio/text conversations.""" -from .utils import log_event, log_flow, time_it_async - -__all__ = ["log_event", "log_flow", "time_it_async"] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index 4cd9cc6b8..d5c3c9b65 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -17,7 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent logger = logging.getLogger(__name__) @@ -31,7 +31,7 @@ class BidirectionalModelSession(abc.ABC): """ @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[dict[str, any]]: + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: """Receive events from the model in standardized format. Converts provider-specific events to a common format that can be @@ -71,7 +71,7 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No """Send tool execution result to the model. Formats and sends tool results according to the provider's specific protocol. - Handles both successful results and error cases. + Handles both successful results and error cases through the result dictionary. """ raise NotImplementedError From 3107e6bac979c137cf575b1fbd45b57e1e3c87fd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Thu, 9 Oct 2025 12:53:41 -0400 Subject: [PATCH 16/23] (feat)bidirectional_streaming: add openai realtime model provider --- pyproject.toml | 12 + .../bidirectional_streaming/__init__.py | 46 +- .../models/__init__.py | 10 +- .../models/novasonic.py | 12 +- .../bidirectional_streaming/models/openai.py | 508 ++++++++++++++++++ ...al_streaming.py => test_bidi_novasonic.py} | 0 .../tests/test_bidi_openai.py | 285 ++++++++++ .../bidirectional_streaming/types/__init__.py | 4 + .../types/bidirectional_streaming.py | 50 +- 9 files changed, 916 insertions(+), 11 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py rename src/strands/experimental/bidirectional_streaming/tests/{test_bidirectional_streaming.py => test_bidi_novasonic.py} (100%) create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/pyproject.toml b/pyproject.toml index 3b8866f4a..2900719ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,12 +53,24 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +bidirectional-streaming-nova = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", +] +bidirectional-streaming-openai = [ + "pyaudio>=0.2.13", + "websockets>=12.0,<14.0", +] bidirectional-streaming = [ "pyaudio>=0.2.13", "rx>=3.2.0", "smithy-aws-core>=0.0.1", "pytz", "aws_sdk_bedrock_runtime", + "websockets>=12.0,<14.0", ] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 52822711a..a6af41dff 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,2 +1,46 @@ -"""Bidirectional streaming package for real-time audio/text conversations.""" +""" +Bidirectional streaming package. +""" +# Main components - Primary user interface +from .agent.agent import BidirectionalAgent + +# Model providers - What users need to create models +from .models.novasonic import NovaSonicBidirectionalModel +from .models.openai import OpenAIRealtimeBidirectionalModel + +# Event types - For type hints and event handling +from .types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + TextOutputEvent, + InterruptionDetectedEvent, + BidirectionalStreamEvent, + VoiceActivityEvent, + UsageMetricsEvent, +) + +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + +__all__ = [ + # Main interface + "BidirectionalAgent", + + # Model providers + "NovaSonicBidirectionalModel", + "OpenAIRealtimeBidirectionalModel", + + # Event types + "AudioInputEvent", + "AudioOutputEvent", + "TextOutputEvent", + "InterruptionDetectedEvent", + "BidirectionalStreamEvent", + "VoiceActivityEvent", + "UsageMetricsEvent", + + # Model interface + "BidirectionalModel", + "BidirectionalModelSession", +] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 6cba974e0..4a11f9e4a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -2,5 +2,13 @@ from .bidirectional_model import BidirectionalModel, BidirectionalModelSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession +from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession -__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] +__all__ = [ + "BidirectionalModel", + "BidirectionalModelSession", + "NovaSonicBidirectionalModel", + "NovaSonicSession", + "OpenAIRealtimeBidirectionalModel", + "OpenAIRealtimeSession" +] \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7f7937ef1..bc00b7e91 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -35,6 +35,7 @@ BidirectionalConnectionStartEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent, ) from .bidirectional_model import BidirectionalModel, BidirectionalModelSession @@ -488,9 +489,16 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No return {"interruptionDetected": interruption} - # Handle usage events (ignore) + # Handle usage events - convert to standardized format elif "usageEvent" in nova_event: - return None + usage_data = nova_event["usageEvent"] + usage_metrics: UsageMetricsEvent = { + "totalTokens": usage_data.get("totalTokens"), + "inputTokens": usage_data.get("totalInputTokens"), + "outputTokens": usage_data.get("totalOutputTokens"), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens") + } + return {"usageMetrics": usage_metrics} # Handle content start events (track role) elif "contentStart" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py new file mode 100644 index 000000000..0fa859db9 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -0,0 +1,508 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import AsyncIterable + +import websockets +from websockets.exceptions import ConnectionClosed +from websockets.client import WebSocketClientProtocol + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, + VoiceActivityEvent, +) +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# OpenAI Realtime API configuration +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" + +AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": AUDIO_FORMAT, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + } + }, + "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + }, +} + + +class OpenAIRealtimeSession(BidirectionalModelSession): + """OpenAI Realtime API session for real-time audio/text streaming. + + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: + """Initialize OpenAI Realtime session.""" + self.websocket = websocket + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + + self._event_queue = asyncio.Queue() + self._response_task = None + self._function_call_buffer = {} + + logger.debug("OpenAI Realtime session initialized: %s", self.session_id) + + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + + async def _create_conversation_item(self, item_data: dict) -> None: + """Create conversation item and trigger response.""" + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize session with configuration.""" + try: + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + if messages: + await self._add_conversation_history(messages) + + self._response_task = asyncio.create_task(self._process_responses()) + logger.info("OpenAI Realtime session initialized successfully") + + except Exception as e: + logger.error("Error during OpenAI Realtime initialization: %s", e) + raise + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + """Build session configuration for OpenAI Realtime API.""" + config = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + custom_config = self.config.get("session", {}) + supported_params = { + "type", "output_modalities", "instructions", "voice", "audio", + "tools", "tool_choice", "input_audio_format", "output_audio_format", + "input_audio_transcription", "turn_detection" + } + + for key, value in custom_config.items(): + if key in supported_params: + config[key] = value + else: + logger.warning("Ignoring unsupported session parameter: %s", key) + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI function format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + else: + schema = input_schema + + openai_tool = { + "type": "function", + "function": { + "name": tool["name"], + "description": tool["description"], + "parameters": schema + } + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session.""" + for message in messages: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": message["role"], "content": []} + } + + content = message.get("content", "") + if isinstance(content, str): + conversation_item["item"]["content"].append({"type": "input_text", "text": content}) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + conversation_item["item"]["content"].append({"type": "input_text", "text": item.get("text", "")}) + + await self._send_event(conversation_item) + + async def _process_responses(self) -> None: + """Process incoming WebSocket messages.""" + logger.debug("OpenAI Realtime response processor started") + + try: + async for message in self.websocket: + if not self._active: + break + + try: + event = json.loads(message) + await self._event_queue.put(event) + except json.JSONDecodeError as e: + logger.warning("Failed to parse OpenAI event: %s", e) + continue + + except ConnectionClosed: + logger.debug("OpenAI Realtime WebSocket connection closed") + except Exception as e: + logger.error("Error in OpenAI Realtime response processing: %s", e) + finally: + self._active = False + logger.debug("OpenAI Realtime response processor stopped") + + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive OpenAI events and convert to Strands format.""" + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + while self._active: + try: + openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + provider_event = self._convert_openai_event(openai_event) + if provider_event: + yield provider_event + except asyncio.TimeoutError: + continue + + except Exception as e: + logger.error("Error receiving OpenAI Realtime event: %s", e) + finally: + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "openai_realtime"}, + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: + """Convert OpenAI events to Strands format.""" + event_type = openai_event.get("type") + + # Audio output + if event_type == "response.output_audio.delta": + audio_data = base64.b64decode(openai_event["delta"]) + audio_output: AudioOutputEvent = { + "audioData": audio_data, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": None, + } + return {"audioOutput": audio_output} + + # Text output using helper method + elif event_type == "response.output_text.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + elif event_type == "response.output_audio_transcript.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + # User transcription + elif event_type == "conversation.item.input_audio_transcription.delta": + transcript_delta = openai_event.get("delta", "") + return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.completed": + transcript = openai_event.get("transcript", "") + return self._create_text_event(transcript, "user") if transcript.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + return self._create_text_event(text, "user") if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + return {"toolUse": tool_use} + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Error parsing function arguments for %s: %s", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection using helper method + elif event_type == "input_audio_buffer.speech_started": + return self._create_voice_activity_event("speech_started") + elif event_type == "input_audio_buffer.speech_stopped": + return self._create_voice_activity_event("speech_stopped") + elif event_type == "input_audio_buffer.timeout_triggered": + return self._create_voice_activity_event("timeout") + + # Lifecycle events (log only) + elif event_type == "conversation.item.retrieve": + item = openai_event.get("item", {}) + logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) + return None + + elif event_type == "conversation.item.added": + logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) + + item = openai_event.get("item", {}) + if item.get("type") == "message" and item.get("role") == "assistant": + content_parts = item.get("content", []) + if content_parts: + message_content = [] + for content_part in content_parts: + if content_part.get("type") == "output_text": + message_content.append({"type": "text", "text": content_part.get("text", "")}) + elif content_part.get("type") == "output_audio": + transcript = content_part.get("transcript", "") + if transcript: + message_content.append({"type": "text", "text": transcript}) + + if message_content: + message = {"role": "assistant", "content": message_content} + return {"messageStop": {"message": message}} + return None + + elif event_type in ["response.output_item.added", "response.output_item.done", + "response.content_part.added", "response.content_part.done"]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": function_name, "arguments": ""} + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", + "session.created", "session.updated"]: + logger.debug("OpenAI %s event", event_type) + return None + + elif event_type == "error": + logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) + return None + + else: + logger.debug("Unhandled OpenAI event type: %s", event_type) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to OpenAI for processing.""" + if not self._require_active(): + return + + audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content to OpenAI for processing.""" + if not self._require_active(): + return + + item_data = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}] + } + await self._create_conversation_item(item_data) + + async def send_interrupt(self) -> None: + """Send interruption signal to OpenAI.""" + if not self._require_active(): + return + + await self._send_event({"type": "response.cancel"}) + + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: + """Send tool result back to OpenAI.""" + if not self._require_active(): + return + + logger.debug("OpenAI tool result send: %s", tool_use_id) + result_text = json.dumps(result) if not isinstance(result, str) else result + + item_data = { + "type": "function_call_output", + "call_id": tool_use_id, + "output": result_text + } + await self._create_conversation_item(item_data) + + async def close(self) -> None: + """Close session and cleanup resources.""" + if not self._active: + return + + logger.debug("OpenAI Realtime cleanup - starting connection close") + self._active = False + + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + await self.websocket.close() + except Exception as e: + logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + + logger.debug("OpenAI Realtime connection closed") + + async def _send_event(self, event: dict[str, any]) -> None: + """Send event to OpenAI via WebSocket.""" + try: + message = json.dumps(event) + await self.websocket.send(message) + logger.debug("Sent OpenAI event: %s", event.get("type")) + except Exception as e: + logger.error("Error sending OpenAI event: %s", e) + raise + + +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """OpenAI Realtime API provider for Strands bidirectional streaming. + + Provides real-time audio/text communication through OpenAI's Realtime API + with WebSocket connections, voice activity detection, and function calling. + """ + + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + **config: any + ) -> None: + """Initialize OpenAI Realtime bidirectional model.""" + self.model = model + self.api_key = api_key + self.config = config + + import os + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create bidirectional connection to OpenAI Realtime API.""" + logger.info("Creating OpenAI Realtime connection...") + + try: + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + session = OpenAIRealtimeSession(websocket, self.config) + await session.initialize(system_prompt, tools, messages) + + logger.info("OpenAI Realtime connection established") + return session + + except Exception as e: + logger.error("OpenAI connection error: %s", e) + raise \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py rename to src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py new file mode 100644 index 000000000..098ec4a39 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +"""Test OpenAI Realtime API speech-to-speech interaction.""" + +import asyncio +import os +import sys +import time +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +import pyaudio + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel + + +async def play(context): + """Handle audio playback with interruption support.""" + audio = pyaudio.PyAudio() + + try: + speaker = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # OpenAI Realtime uses 24kHz + output=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + # Check for interruption + if context.get("interrupted", False): + # Clear audio queue on interruption + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get audio data with timeout + try: + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + # Play in chunks to allow interruption + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if context.get("interrupted", False) or not context["active"]: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Brief pause for responsiveness + + except asyncio.TimeoutError: + continue + + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Audio playback error: {e}") + finally: + try: + speaker.close() + except: + pass + audio.terminate() + + +async def record(context): + """Handle microphone recording.""" + audio = pyaudio.PyAudio() + + try: + microphone = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # Match OpenAI's expected input rate + input=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + await context["audio_in"].put(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Microphone recording error: {e}") + finally: + try: + microphone.close() + except: + pass + audio.terminate() + + +async def receive(agent, context): + """Handle events from the agent.""" + try: + async for event in agent.receive(): + if not context["active"]: + break + + # Handle audio output + if "audioOutput" in event: + audio_data = event["audioOutput"]["audioData"] + + if not context.get("interrupted", False): + await context["audio_out"].put(audio_data) + + # Handle text output (transcripts) + elif "textOutput" in event: + text_output = event["textOutput"] + role = text_output.get("role", "assistant") + text = text_output.get("text", "").strip() + + if text: + if role == "user": + print(f"User: {text}") + elif role == "assistant": + print(f"Assistant: {text}") + + # Handle interruption detection + elif "interruptionDetected" in event: + context["interrupted"] = True + + # Handle connection events + elif "BidirectionalConnectionStart" in event: + pass # Silent connection start + elif "BidirectionalConnectionEnd" in event: + context["active"] = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receive handler error: {e}") + finally: + pass + + +async def send(agent, context): + """Send audio from microphone to agent.""" + try: + while context["active"]: + try: + audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) + + # Create audio event in expected format + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1 + } + + await agent.send(audio_event) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Send handler error: {e}") + finally: + pass + + +async def main(): + """Main test function for OpenAI voice chat.""" + print("Starting OpenAI Realtime API test...") + + # Check API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY environment variable not set") + return False + + # Check audio system + try: + audio = pyaudio.PyAudio() + audio.terminate() + except Exception as e: + print(f"Audio system error: {e}") + return False + + # Create OpenAI model + model = OpenAIRealtimeBidirectionalModel( + model="gpt-4o-realtime-preview", + api_key=api_key, + session={ + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700 + } + }, + "output": { + "format": {"type": "audio/pcm", "rate": 24000}, + "voice": "alloy" + } + } + } + ) + + # Create agent + agent = BidirectionalAgent( + model=model, + system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." + ) + + # Start the session + await agent.start() + + # Create shared context + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "interrupted": False, + "start_time": time.time() + } + + print("Speak into your microphone. Press Ctrl+C to stop.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + except Exception as e: + print(f"\nError during voice chat: {e}") + finally: + print("Cleaning up...") + context["active"] = False + + try: + await agent.end() + except Exception as e: + print(f"Cleanup error: {e}") + + return True + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Test error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 510285f06..412061146 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -13,6 +13,8 @@ BidirectionalStreamEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent, + VoiceActivityEvent, ) __all__ = [ @@ -23,6 +25,8 @@ "BidirectionalStreamEvent", "InterruptionDetectedEvent", "TextOutputEvent", + "UsageMetricsEvent", + "VoiceActivityEvent", "SUPPORTED_AUDIO_FORMATS", "SUPPORTED_SAMPLE_RATES", "SUPPORTED_CHANNELS", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 01d72356a..194698f29 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -116,11 +116,43 @@ class BidirectionalConnectionEndEvent(TypedDict): metadata: Provider-specific connection metadata. """ - reason: Literal["user_request", "timeout", "error"] + reason: Literal["user_request", "timeout", "error", "connection_complete"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] +class VoiceActivityEvent(TypedDict): + """Voice activity detection event for speech monitoring. + + Provides standardized voice activity detection events across providers + to enable speech-aware applications and better conversation flow. + + Attributes: + activityType: Type of voice activity detected. + """ + + activityType: Literal["speech_started", "speech_stopped", "timeout"] + + +class UsageMetricsEvent(TypedDict): + """Token usage and performance tracking. + + Provides standardized usage metrics across providers for cost monitoring + and performance optimization. + + Attributes: + totalTokens: Total tokens used in the interaction. + inputTokens: Tokens used for input processing. + outputTokens: Tokens used for output generation. + audioTokens: Tokens used specifically for audio processing. + """ + + totalTokens: Optional[int] + inputTokens: Optional[int] + outputTokens: Optional[int] + audioTokens: Optional[int] + + class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. @@ -134,11 +166,15 @@ class BidirectionalStreamEvent(StreamEvent, total=False): interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. + voiceActivity: Voice activity detection events. + usageMetrics: Token usage and performance metrics. """ - audioOutput: AudioOutputEvent - audioInput: AudioInputEvent - textOutput: TextOutputEvent - interruptionDetected: InterruptionDetectedEvent - BidirectionalConnectionStart: BidirectionalConnectionStartEvent - BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + audioOutput: Optional[AudioOutputEvent] + audioInput: Optional[AudioInputEvent] + textOutput: Optional[TextOutputEvent] + interruptionDetected: Optional[InterruptionDetectedEvent] + BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] + BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] + voiceActivity: Optional[VoiceActivityEvent] + usageMetrics: Optional[UsageMetricsEvent] From da8b86ca77e4f324b6cfc2a1d7ce756ec8a6d310 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Fri, 10 Oct 2025 10:29:08 -0400 Subject: [PATCH 17/23] fix function calling --- .../bidirectional_streaming/__init__.py | 15 +++++++-------- .../bidirectional_streaming/agent/agent.py | 1 - .../event_loop/bidirectional_event_loop.py | 1 - .../bidirectional_streaming/models/novasonic.py | 1 - .../bidirectional_streaming/models/openai.py | 14 ++++++-------- .../tests/test_bidi_openai.py | 2 ++ 6 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index a6af41dff..aeb335dea 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,10 +1,12 @@ -""" -Bidirectional streaming package. +"""Bidirectional streaming package. """ # Main components - Primary user interface from .agent.agent import BidirectionalAgent +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + # Model providers - What users need to create models from .models.novasonic import NovaSonicBidirectionalModel from .models.openai import OpenAIRealtimeBidirectionalModel @@ -13,16 +15,13 @@ from .types.bidirectional_streaming import ( AudioInputEvent, AudioOutputEvent, - TextOutputEvent, - InterruptionDetectedEvent, BidirectionalStreamEvent, - VoiceActivityEvent, + InterruptionDetectedEvent, + TextOutputEvent, UsageMetricsEvent, + VoiceActivityEvent, ) -# Advanced interfaces (for custom implementations) -from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession - __all__ = [ # Main interface "BidirectionalAgent", diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 68d371a51..0cd90063d 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -23,7 +23,6 @@ from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent - logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 16be08aaf..340cd9267 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -22,7 +22,6 @@ from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession - logger = logging.getLogger(__name__) # Session constants diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index bc00b7e91..4e4952fa9 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -37,7 +37,6 @@ TextOutputEvent, UsageMetricsEvent, ) - from .bidirectional_model import BidirectionalModel, BidirectionalModelSession logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 0fa859db9..76bf9f50d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -12,8 +12,8 @@ from typing import AsyncIterable import websockets -from websockets.exceptions import ConnectionClosed from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosed from ....types.content import Messages from ....types.tools import ToolSpec, ToolUse @@ -23,7 +23,6 @@ BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, BidirectionalStreamEvent, - InterruptionDetectedEvent, TextOutputEvent, VoiceActivityEvent, ) @@ -142,7 +141,7 @@ def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] return config def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI function format.""" + """Convert Strands tool specifications to OpenAI Realtime API format.""" openai_tools = [] for tool in tools: @@ -152,13 +151,12 @@ def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: else: schema = input_schema + # OpenAI Realtime API expects flat structure, not nested under "function" openai_tool = { "type": "function", - "function": { - "name": tool["name"], - "description": tool["description"], - "parameters": schema - } + "name": tool["name"], + "description": tool["description"], + "parameters": schema } openai_tools.append(openai_tool) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py index 098ec4a39..660040f3e 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -11,6 +11,7 @@ sys.path.insert(0, str(Path(__file__).parent / "src")) import pyaudio +from strands_tools import calculator from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel @@ -229,6 +230,7 @@ async def main(): # Create agent agent = BidirectionalAgent( model=model, + tools=[calculator], system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." ) From 9368c82c76f6ca858d355c68624e078c8b95cf4e Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 14 Oct 2025 08:30:23 -0400 Subject: [PATCH 18/23] feat(tool_executor): Plug tool executor into bidirectional streaming implementation --- .../bidirectional_streaming/__init__.py | 41 +- .../bidirectional_streaming/agent/agent.py | 297 +++++++++- .../event_loop/bidirectional_event_loop.py | 179 +++--- .../models/__init__.py | 9 +- .../models/novasonic.py | 23 +- .../bidirectional_streaming/models/openai.py | 522 ++++++++++++++++++ .../tests/test_bidi_openai.py | 317 +++++++++++ .../tests/test_bidirectional_streaming.py | 27 + 8 files changed, 1317 insertions(+), 98 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 52822711a..844a8a1f8 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -1,2 +1,41 @@ -"""Bidirectional streaming package for real-time audio/text conversations.""" +"""Bidirectional streaming package.""" +# Main components - Primary user interface +from .agent.agent import BidirectionalAgent + +# Advanced interfaces (for custom implementations) +from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession + +# Model providers - What users need to create models +from .models.novasonic import NovaSonicBidirectionalModel +from .models.openai import OpenAIRealtimeBidirectionalModel + +# Event types - For type hints and event handling +from .types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, + UsageMetricsEvent, + VoiceActivityEvent, +) + +__all__ = [ + # Main interface + "BidirectionalAgent", + # Model providers + "NovaSonicBidirectionalModel", + "OpenAIRealtimeBidirectionalModel", + # Event types + "AudioInputEvent", + "AudioOutputEvent", + "TextOutputEvent", + "InterruptionDetectedEvent", + "BidirectionalStreamEvent", + "VoiceActivityEvent", + "UsageMetricsEvent", + # Model interface + "BidirectionalModel", + "BidirectionalModelSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 68d371a51..26b964c53 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -13,12 +13,22 @@ """ import asyncio +import json import logging -from typing import AsyncIterable +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, AsyncIterable, Callable, Mapping, Optional +from .... import _identifier +from ....hooks import HookProvider, HookRegistry +from ....telemetry.metrics import EventLoopMetrics from ....tools.executors import ConcurrentToolExecutor +from ....tools.executors._executor import ToolExecutor from ....tools.registry import ToolRegistry -from ....types.content import Messages +from ....tools.watcher import ToolWatcher +from ....types.content import Message, Messages +from ....types.tools import ToolResult, ToolUse +from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent @@ -26,6 +36,9 @@ logger = logging.getLogger(__name__) +_DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" + class BidirectionalAgent: """Agent for bidirectional streaming conversations. @@ -34,12 +47,125 @@ class BidirectionalAgent: sessions. Supports concurrent tool execution and interruption handling. """ + class ToolCaller: + """Call tool as a function for bidirectional agent.""" + + def __init__(self, agent: "BidirectionalAgent") -> None: + """Initialize tool caller with agent reference.""" + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. + self._agent = agent + + def __getattr__(self, name: str) -> Callable[..., Any]: + """Call tool as a function. + + This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). + + Args: + name: The name of the attribute (tool) being accessed. + + Returns: + A function that when called will execute the named tool. + + Raises: + AttributeError: If no tool with the given name exists or if multiple tools match the given name. + """ + + def caller( + user_message_override: Optional[str] = None, + record_direct_tool_call: Optional[bool] = None, + **kwargs: Any, + ) -> Any: + """Call a tool directly by name. + + Args: + user_message_override: Optional custom message to record instead of default + record_direct_tool_call: Whether to record direct tool calls in message history. + For bidirectional agents, this is always True to maintain conversation history. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result returned by the tool. + + Raises: + AttributeError: If the tool doesn't exist. + """ + normalized_name = self._find_normalized_tool_name(name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs + + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + _ = event + + return tool_results[0] + + def tcall() -> ToolResult: + return asyncio.run(acall()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(tcall) + tool_result = future.result() + + # Always record direct tool calls for bidirectional agents to maintain conversation history + # Use agent's record_direct_tool_call setting if not overridden + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + + return tool_result + + return caller + + def _find_normalized_tool_name(self, name: str) -> str: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + def __init__( self, model: BidirectionalModel, tools: list | None = None, system_prompt: str | None = None, messages: Messages | None = None, + record_direct_tool_call: bool = True, + load_tools_from_directory: bool = False, + agent_id: Optional[str] = None, + name: Optional[str] = None, + tool_executor: Optional[ToolExecutor] = None, + hooks: Optional[list[HookProvider]] = None, + trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + description: Optional[str] = None, ): """Initialize bidirectional agent with required model and optional configuration. @@ -48,24 +174,177 @@ def __init__( tools: Optional list of tools available to the model. system_prompt: Optional system prompt for conversations. messages: Optional conversation history to initialize with. + record_direct_tool_call: Whether to record direct tool calls in message history. + load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. + agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + name: Name of the Agent. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + hooks: Hooks to be added to the agent hook registry. + trace_attributes: Custom trace attributes to apply to the agent's trace span. + description: Description of what the Agent does. """ self.model = model self.system_prompt = system_prompt self.messages = messages or [] - - # Initialize tool registry using existing Strands infrastructure + + # Agent identification + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self.name = name or _DEFAULT_AGENT_NAME + self.description = description + + # Tool execution configuration + self.record_direct_tool_call = record_direct_tool_call + self.load_tools_from_directory = load_tools_from_directory + + # Process trace attributes to ensure they're of compatible types + self.trace_attributes: dict[str, AttributeValue] = {} + if trace_attributes: + for k, v in trace_attributes.items(): + if isinstance(v, (str, int, float, bool)) or ( + isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v) + ): + self.trace_attributes[k] = v + + # Initialize tool registry self.tool_registry = ToolRegistry() - if tools: + + if tools is not None: self.tool_registry.process_tools(tools) - self.tool_registry.initialize_tools() - - # Initialize tool executor for concurrent execution - self.tool_executor = ConcurrentToolExecutor() + + self.tool_registry.initialize_tools(self.load_tools_from_directory) + + # Initialize tool watcher if directory loading is enabled + if self.load_tools_from_directory: + self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) + + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + + # Initialize hooks system + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + + # Initialize other components + self.event_loop_metrics = EventLoopMetrics() + self.tool_caller = BidirectionalAgent.ToolCaller(self) # Session management self._session = None self._output_queue = asyncio.Queue() + @property + def tool(self) -> ToolCaller: + """Call tool as a function. + + Returns: + Tool caller through which user can invoke tool as a function. + + Example: + ``` + agent = BidirectionalAgent(model=model, tools=[calculator]) + agent.tool.calculator(expression="2+2") + ``` + """ + return self.tool_caller + + @property + def tool_names(self) -> list[str]: + """Get a list of all registered tool names. + + Returns: + Names of all tools available to this agent. + """ + all_tools = self.tool_registry.get_all_tools_config() + return list(all_tools.keys()) + + def _record_tool_execution( + self, + tool: ToolUse, + tool_result: ToolResult, + user_message_override: Optional[str], + ) -> None: + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. + """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + + # Create user message describing the tool call + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") + + user_msg_content = [ + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} + ] + + # Add override message if provided + if user_message_override: + user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + + # Create the message sequence + user_msg: Message = { + "role": "user", + "content": user_msg_content, + } + tool_use_msg: Message = { + "role": "assistant", + "content": [{"toolUse": filtered_tool}], + } + tool_result_msg: Message = { + "role": "user", + "content": [{"toolResult": tool_result}], + } + assistant_msg: Message = { + "role": "assistant", + "content": [{"text": f"agent.tool.{tool['name']} was called."}], + } + + # Add to message history + self.messages.append(user_msg) + self.messages.append(tool_use_msg) + self.messages.append(tool_result_msg) + self.messages.append(assistant_msg) + + logger.debug("Direct tool call recorded in message history: %s", tool["name"]) + + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + async def start(self) -> None: """Start a persistent bidirectional conversation session. diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index 16be08aaf..69f5d759d 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -12,12 +12,13 @@ """ import asyncio -import json import logging import traceback import uuid from ....tools._validator import validate_and_prepare_tools +from ....telemetry.metrics import Trace +from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse from ..models.bidirectional_model import BidirectionalModelSession @@ -59,6 +60,9 @@ def __init__(self, model_session: BidirectionalModelSession, agent: "Bidirection # Interruption handling (model-agnostic) self.interrupted = False self.interruption_lock = asyncio.Lock() + + # Tool execution tracking + self.tool_count = 0 async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection: @@ -195,11 +199,11 @@ async def _handle_interruption(session: BidirectionalConnection) -> None: # Cancel all pending tool execution tasks cancelled_tools = 0 - for task_id, task in list(session.pending_tool_tasks.items()): + for _task_id, task in list(session.pending_tool_tasks.items()): if not task.done(): task.cancel() cancelled_tools += 1 - logger.debug("Tool task cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", _task_id) if cancelled_tools > 0: logger.debug("Tool tasks cancelled: %d", cancelled_tools) @@ -274,7 +278,8 @@ async def _process_model_events(session: BidirectionalConnection) -> None: # Queue tool requests for concurrent execution if strands_event.get("toolUse"): - logger.debug("Tool queued: %s", strands_event["toolUse"].get("name")) + tool_name = strands_event["toolUse"].get("name") + logger.debug("Tool usage detected: %s", tool_name) await session.tool_queue.put(strands_event["toolUse"]) continue @@ -316,7 +321,13 @@ async def _process_tool_execution(session: BidirectionalConnection) -> None: while session.active: try: tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) - logger.debug("Tool execution started: %s (id: %s)", tool_use.get("name"), tool_use.get("toolUseId")) + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + + session.tool_count += 1 + print(f"\nTool #{session.tool_count}: {tool_name}") + + logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) task_id = str(uuid.uuid4()) task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) @@ -330,11 +341,11 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: # Log completion status if completed_task.cancelled(): - logger.debug("Tool task cleanup cancelled: %s", task_id) + logger.debug("Tool task cancelled: %s", task_id) elif completed_task.exception(): - logger.error("Tool task cleanup error: %s - %s", task_id, str(completed_task.exception())) + logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) else: - logger.debug("Tool task cleanup success: %s", task_id) + logger.debug("Tool task completed: %s", task_id) except Exception as e: logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) @@ -365,94 +376,106 @@ def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: - """Execute tool using Strands infrastructure with interruption support. - - Executes tools using the existing Strands tool system with proper asyncio - cancellation handling. Tool execution is stopped via task cancellation, - not manual state checks. - + """Execute tool using the complete Strands tool execution system. + + Uses proper Strands ToolExecutor system with validation, error handling, + and event streaming. + Args: session: BidirectionalConnection for context. tool_use: Tool use event to execute. """ tool_name = tool_use.get("name") tool_id = tool_use.get("toolUseId") - + + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) + try: - # Create message structure for existing tool system + # Create message structure for validation tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} - + + # Use Strands validation system tool_uses: list[ToolUse] = [] tool_results: list[ToolResult] = [] invalid_tool_use_ids: list[str] = [] - - # Validate using existing Strands validation + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) - - # Filter valid tool uses + + # Filter valid tools valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] - + if not valid_tool_uses: - logger.warning("Tool validation failed: %s (id: %s)", tool_name, tool_id) + logger.warning("No valid tools after validation: %s", tool_name) return - - # Execute tools directly (simpler approach for bidirectional) - for tool_use in valid_tool_uses: - tool_func = session.agent.tool_registry.registry.get(tool_use["name"]) - - if tool_func: - try: - actual_func = _extract_callable_function(tool_func) - - # Execute tool function with provided input - result = actual_func(**tool_use.get("input", {})) - - tool_result = _create_success_result(tool_use["toolUseId"], result) - tool_results.append(tool_result) - - except Exception as e: - logger.error("Tool execution failed: %s - %s", tool_name, str(e)) - tool_result = _create_error_result(tool_use["toolUseId"], str(e)) - tool_results.append(tool_result) - else: - logger.warning("Tool not found: %s", tool_name) - - # Send results through provider-specific session - for result in tool_results: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), result) - - logger.debug("Tool execution completed: %s (%d results)", tool_name, len(tool_results)) - + + # Create invocation state for tool execution + invocation_state = { + "agent": session.agent, + "model": session.agent.model, + "messages": session.agent.messages, + "system_prompt": session.agent.system_prompt, + } + + # Create cycle trace and span + cycle_trace = Trace("Bidirectional Tool Execution") + cycle_span = None + + tool_events = session.agent.tool_executor._execute( + session.agent, + valid_tool_uses, + tool_results, + cycle_trace, + cycle_span, + invocation_state + ) + + # Process tool events and send results to provider + async for tool_event in tool_events: + if isinstance(tool_event, ToolResultEvent): + tool_result = tool_event.tool_result + tool_use_id = tool_result.get("toolUseId") + + # Send result through provider-specific session + await session.model_session.send_tool_result(tool_use_id, tool_result) + logger.debug("Tool result sent: %s", tool_use_id) + + # Handle streaming events if needed later + elif isinstance(tool_event, ToolStreamEvent): + logger.debug("Tool stream event: %s", tool_event) + pass + + # Add tool result message to conversation history + if tool_results: + from ....hooks import MessageAddedEvent + + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + session.agent.messages.append(tool_result_message) + session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) + logger.debug("Tool result message added to history: %s", tool_name) + + logger.debug("Tool execution completed: %s", tool_name) + except asyncio.CancelledError: - # Task was cancelled due to interruption - this is expected behavior - logger.debug("Tool task cancelled gracefully: %s (id: %s)", tool_name, tool_id) - raise # Re-raise to properly handle cancellation + logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) + raise except Exception as e: - logger.error("Tool execution error: %s - %s", tool_use.get("name"), str(e)) + logger.error("Tool execution error: %s - %s", tool_name, str(e)) + # Send error result + error_result: ToolResult = { + "toolUseId": tool_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } try: - await session.model_session.send_tool_result(tool_use.get("toolUseId"), {"error": str(e)}) - except Exception as send_error: - logger.error("Tool error send failed: %s", str(send_error)) - - -def _extract_callable_function(tool_func: any) -> any: - """Extract the callable function from different tool object types.""" - if hasattr(tool_func, "_tool_func"): - return tool_func._tool_func - elif hasattr(tool_func, "func"): - return tool_func.func - elif callable(tool_func): - return tool_func - else: - raise ValueError(f"Tool function not callable: {type(tool_func).__name__}") - - -def _create_success_result(tool_use_id: str, result: any) -> dict[str, any]: - """Create a successful tool result.""" - return {"toolUseId": tool_use_id, "status": "success", "content": [{"text": json.dumps(result)}]} + await session.model_session.send_tool_result(tool_id, error_result) + logger.debug("Error result sent: %s", tool_id) + except Exception: + logger.error("Failed to send error result: %s", tool_id) + pass # Session might be closed -def _create_error_result(tool_use_id: str, error: str) -> dict[str, any]: - """Create an error tool result.""" - return {"toolUseId": tool_use_id, "status": "error", "content": [{"text": f"Error: {error}"}]} diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 6cba974e0..882f89eef 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -3,4 +3,11 @@ from .bidirectional_model import BidirectionalModel, BidirectionalModelSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession -__all__ = ["BidirectionalModel", "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession"] +__all__ = [ + "BidirectionalModel", + "BidirectionalModelSession", + "NovaSonicBidirectionalModel", + "NovaSonicSession", + "OpenAIRealtimeBidirectionalModel", + "OpenAIRealtimeSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 7f7937ef1..a1d61e11a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -121,7 +121,7 @@ async def initialize( init_events = self._build_initialization_events(system_prompt, tools or [], messages) - logger.debug(f"Nova Sonic initialization - sending {len(init_events)} events") + logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) await self._send_initialization_events(init_events) logger.info("Nova Sonic connection initialized successfully") @@ -146,7 +146,7 @@ def _build_initialization_events( async def _send_initialization_events(self, events: list[str]) -> None: """Send initialization events with required delays.""" - for i, event in enumerate(events): + for _i, event in enumerate(events): await self._send_nova_event(event) await asyncio.sleep(EVENT_DELAY) @@ -167,12 +167,12 @@ async def _process_responses(self) -> None: await asyncio.sleep(0.1) continue except Exception as e: - logger.warning(f"Nova Sonic response error: {e}") + logger.warning("Nova Sonic response error: %s", e) await asyncio.sleep(0.1) continue except Exception as e: - logger.error(f"Nova Sonic fatal error: {e}") + logger.error("Nova Sonic fatal error: %s", e) finally: logger.debug("Nova Sonic response processor stopped") @@ -190,7 +190,7 @@ async def _handle_response_data(self, response_data: str) -> None: await self._event_queue.put(nova_event) except json.JSONDecodeError as e: - logger.warning(f"Nova Sonic JSON decode error: {e}") + logger.warning("Nova Sonic JSON decode error: %s", e) def _log_event_type(self, nova_event: dict[str, any]) -> None: """Log specific Nova Sonic event types for debugging.""" @@ -383,11 +383,9 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No self._get_content_end_event(content_name), ] - for i, event in enumerate(events): + for _i, event in enumerate(events): await self._send_nova_event(event) - - async def close(self) -> None: """Close Nova Sonic connection with proper cleanup sequence.""" if not self._active: @@ -490,7 +488,14 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No # Handle usage events (ignore) elif "usageEvent" in nova_event: - return None + usage_data = nova_event["usageEvent"] + usage_metrics: UsageMetricsEvent = { + "totalTokens": usage_data.get("totalTokens"), + "inputTokens": usage_data.get("totalInputTokens"), + "outputTokens": usage_data.get("totalOutputTokens"), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens"), + } + return {"usageMetrics": usage_metrics} # Handle content start events (track role) elif "contentStart" in nova_event: diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py new file mode 100644 index 000000000..7c79e3e6c --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -0,0 +1,522 @@ +/Users/mehtarac/Desktop/sdk-python/src/strands/experimental/bidirectional_streaming/models/openai.py + +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import AsyncIterable + +import websockets +from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosed + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + TextOutputEvent, + VoiceActivityEvent, +) +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# OpenAI Realtime API configuration +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" + +AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": AUDIO_FORMAT, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + }, + }, + "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + }, +} + + +class OpenAIRealtimeSession(BidirectionalModelSession): + """OpenAI Realtime API session for real-time audio/text streaming. + + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: + """Initialize OpenAI Realtime session.""" + self.websocket = websocket + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + + self._event_queue = asyncio.Queue() + self._response_task = None + self._function_call_buffer = {} + + logger.debug("OpenAI Realtime session initialized: %s", self.session_id) + + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + + async def _create_conversation_item(self, item_data: dict) -> None: + """Create conversation item and trigger response.""" + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize session with configuration.""" + try: + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + if messages: + await self._add_conversation_history(messages) + + self._response_task = asyncio.create_task(self._process_responses()) + logger.info("OpenAI Realtime session initialized successfully") + + except Exception as e: + logger.error("Error during OpenAI Realtime initialization: %s", e) + raise + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + """Build session configuration for OpenAI Realtime API.""" + config = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + custom_config = self.config.get("session", {}) + supported_params = { + "type", + "output_modalities", + "instructions", + "voice", + "audio", + "tools", + "tool_choice", + "input_audio_format", + "output_audio_format", + "input_audio_transcription", + "turn_detection", + } + + for key, value in custom_config.items(): + if key in supported_params: + config[key] = value + else: + logger.warning("Ignoring unsupported session parameter: %s", key) + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = ( + json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + ) + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema, + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session.""" + for message in messages: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": message["role"], "content": []}, + } + + content = message.get("content", "") + if isinstance(content, str): + conversation_item["item"]["content"].append({"type": "input_text", "text": content}) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + conversation_item["item"]["content"].append( + {"type": "input_text", "text": item.get("text", "")} + ) + + await self._send_event(conversation_item) + + async def _process_responses(self) -> None: + """Process incoming WebSocket messages.""" + logger.debug("OpenAI Realtime response processor started") + + try: + async for message in self.websocket: + if not self._active: + break + + try: + event = json.loads(message) + await self._event_queue.put(event) + except json.JSONDecodeError as e: + logger.warning("Failed to parse OpenAI event: %s", e) + continue + + except ConnectionClosed: + logger.debug("OpenAI Realtime WebSocket connection closed") + except Exception as e: + logger.error("Error in OpenAI Realtime response processing: %s", e) + finally: + self._active = False + logger.debug("OpenAI Realtime response processor stopped") + + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive OpenAI events and convert to Strands format.""" + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + while self._active: + try: + openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + provider_event = self._convert_openai_event(openai_event) + if provider_event: + yield provider_event + except asyncio.TimeoutError: + continue + + except Exception as e: + logger.error("Error receiving OpenAI Realtime event: %s", e) + finally: + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "openai_realtime"}, + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: + """Convert OpenAI events to Strands format.""" + event_type = openai_event.get("type") + + # Audio output + if event_type == "response.output_audio.delta": + audio_data = base64.b64decode(openai_event["delta"]) + audio_output: AudioOutputEvent = { + "audioData": audio_data, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": None, + } + return {"audioOutput": audio_output} + + # Text output using helper method + elif event_type == "response.output_text.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + elif event_type == "response.output_audio_transcript.delta": + return self._create_text_event(openai_event["delta"], "assistant") + + # User transcription + elif event_type == "conversation.item.input_audio_transcription.delta": + transcript_delta = openai_event.get("delta", "") + return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.completed": + transcript = openai_event.get("transcript", "") + return self._create_text_event(transcript, "user") if transcript.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + return self._create_text_event(text, "user") if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + return {"toolUse": tool_use} + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Error parsing function arguments for %s: %s", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection using helper method + elif event_type == "input_audio_buffer.speech_started": + return self._create_voice_activity_event("speech_started") + elif event_type == "input_audio_buffer.speech_stopped": + return self._create_voice_activity_event("speech_stopped") + elif event_type == "input_audio_buffer.timeout_triggered": + return self._create_voice_activity_event("timeout") + + # Lifecycle events (log only) + elif event_type == "conversation.item.retrieve": + item = openai_event.get("item", {}) + logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) + return None + + elif event_type == "conversation.item.added": + logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) + + item = openai_event.get("item", {}) + if item.get("type") == "message" and item.get("role") == "assistant": + content_parts = item.get("content", []) + if content_parts: + message_content = [] + for content_part in content_parts: + if content_part.get("type") == "output_text": + message_content.append({"type": "text", "text": content_part.get("text", "")}) + elif content_part.get("type") == "output_audio": + transcript = content_part.get("transcript", "") + if transcript: + message_content.append({"type": "text", "text": transcript}) + + if message_content: + message = {"role": "assistant", "content": message_content} + return {"messageStop": {"message": message}} + return None + + elif event_type in [ + "response.output_item.added", + "response.output_item.done", + "response.content_part.added", + "response.content_part.done", + ]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = { + "call_id": call_id, + "name": function_name, + "arguments": "", + } + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + elif event_type in [ + "input_audio_buffer.committed", + "input_audio_buffer.cleared", + "session.created", + "session.updated", + ]: + logger.debug("OpenAI %s event", event_type) + return None + + elif event_type == "error": + logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) + return None + + else: + logger.debug("Unhandled OpenAI event type: %s", event_type) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to OpenAI for processing.""" + if not self._require_active(): + return + + audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content to OpenAI for processing.""" + if not self._require_active(): + return + + item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} + await self._create_conversation_item(item_data) + + async def send_interrupt(self) -> None: + """Send interruption signal to OpenAI.""" + if not self._require_active(): + return + + await self._send_event({"type": "response.cancel"}) + + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: + """Send tool result back to OpenAI.""" + if not self._require_active(): + return + + logger.debug("OpenAI tool result send: %s", tool_use_id) + result_text = json.dumps(result) if not isinstance(result, str) else result + + item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} + await self._create_conversation_item(item_data) + + async def close(self) -> None: + """Close session and cleanup resources.""" + if not self._active: + return + + logger.debug("OpenAI Realtime cleanup - starting connection close") + self._active = False + + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + await self.websocket.close() + except Exception as e: + logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + + logger.debug("OpenAI Realtime connection closed") + + async def _send_event(self, event: dict[str, any]) -> None: + """Send event to OpenAI via WebSocket.""" + try: + message = json.dumps(event) + await self.websocket.send(message) + logger.debug("Sent OpenAI event: %s", event.get("type")) + except Exception as e: + logger.error("Error sending OpenAI event: %s", e) + raise + + +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """OpenAI Realtime API provider for Strands bidirectional streaming. + + Provides real-time audio/text communication through OpenAI's Realtime API + with WebSocket connections, voice activity detection, and function calling. + """ + + def __init__(self, model: str = DEFAULT_MODEL, api_key: str | None = None, **config: any) -> None: + """Initialize OpenAI Realtime bidirectional model.""" + self.model = model + self.api_key = api_key + self.config = config + + import os + + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError( + "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." + ) + + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create bidirectional connection to OpenAI Realtime API.""" + logger.info("Creating OpenAI Realtime connection...") + + try: + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + session = OpenAIRealtimeSession(websocket, self.config) + await session.initialize(system_prompt, tools, messages) + + logger.info("OpenAI Realtime connection established") + return session + + except Exception as e: + logger.error("OpenAI connection error: %s", e) + raise diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py new file mode 100644 index 000000000..5ce4b8cb2 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +"""Test OpenAI Realtime API speech-to-speech interaction.""" + +import asyncio +import os +import sys +import time +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel + + +def test_direct_tool_calling(): + """Test direct tool calling functionality.""" + print("Testing direct tool calling...") + + try: + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY not set - skipping test") + return + + model = OpenAIRealtimeBidirectionalModel(model="gpt-4o-realtime-preview", api_key=api_key) + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + +async def play(context): + """Handle audio playback with interruption support.""" + audio = pyaudio.PyAudio() + + try: + speaker = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # OpenAI Realtime uses 24kHz + output=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + # Check for interruption + if context.get("interrupted", False): + # Clear audio queue on interruption + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get audio data with timeout + try: + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + # Play in chunks to allow interruption + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if context.get("interrupted", False) or not context["active"]: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Brief pause for responsiveness + + except asyncio.TimeoutError: + continue + + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Audio playback error: {e}") + finally: + try: + speaker.close() + except Exception: + pass + audio.terminate() + + +async def record(context): + """Handle microphone recording.""" + audio = pyaudio.PyAudio() + + try: + microphone = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # Match OpenAI's expected input rate + input=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + await context["audio_in"].put(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Microphone recording error: {e}") + finally: + try: + microphone.close() + except Exception: + pass + audio.terminate() + + +async def receive(agent, context): + """Handle events from the agent.""" + try: + async for event in agent.receive(): + if not context["active"]: + break + + # Handle audio output + if "audioOutput" in event: + audio_data = event["audioOutput"]["audioData"] + + if not context.get("interrupted", False): + await context["audio_out"].put(audio_data) + + # Handle text output (transcripts) + elif "textOutput" in event: + text_output = event["textOutput"] + role = text_output.get("role", "assistant") + text = text_output.get("text", "").strip() + + if text: + if role == "user": + print(f"User: {text}") + elif role == "assistant": + print(f"Assistant: {text}") + + # Handle interruption detection + elif "interruptionDetected" in event: + context["interrupted"] = True + + # Handle connection events + elif "BidirectionalConnectionStart" in event: + pass # Silent connection start + elif "BidirectionalConnectionEnd" in event: + context["active"] = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receive handler error: {e}") + finally: + pass + + +async def send(agent, context): + """Send audio from microphone to agent.""" + try: + while context["active"]: + try: + audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) + + # Create audio event in expected format + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1 + } + + await agent.send(audio_event) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Send handler error: {e}") + finally: + pass + + +async def main(): + """Main test function for OpenAI voice chat.""" + print("Starting OpenAI Realtime API test...") + + # Check API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY environment variable not set") + return False + + # Check audio system + try: + audio = pyaudio.PyAudio() + audio.terminate() + except Exception as e: + print(f"Audio system error: {e}") + return False + + # Create OpenAI model + model = OpenAIRealtimeBidirectionalModel( + model="gpt-4o-realtime-preview", + api_key=api_key, + session={ + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700 + } + }, + "output": { + "format": {"type": "audio/pcm", "rate": 24000}, + "voice": "alloy" + } + } + } + ) + + # Create agent + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt=( + "You are a helpful voice assistant. Keep your responses brief and natural. " + "Say hello when you first connect." + ) + ) + + # Start the session + await agent.start() + + # Create shared context + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "interrupted": False, + "start_time": time.time() + } + + print("Speak into your microphone. Press Ctrl+C to stop.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + except Exception as e: + print(f"\nError during voice chat: {e}") + finally: + print("Cleaning up...") + context["active"] = False + + try: + await agent.end() + except Exception as e: + print(f"Cleanup error: {e}") + + return True + + +if __name__ == "__main__": + # Test direct tool calling first + print("OpenAI Realtime API Test Suite") + test_direct_tool_calling() + + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Test error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py index b31607966..8c3ae3b4c 100644 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py @@ -10,6 +10,7 @@ # Add the src directory to Python path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import os import time import pyaudio @@ -19,6 +20,29 @@ from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel +def test_direct_tools(): + """Test direct tool calling.""" + print("Testing direct tool calling...") + + # Check AWS credentials + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): + print("AWS credentials not set - skipping test") + return + + try: + model = NovaSonicBidirectionalModel() + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + async def play(context): """Play audio output with responsive interruption support.""" audio = pyaudio.PyAudio() @@ -195,4 +219,7 @@ async def main(duration=180): if __name__ == "__main__": + # Test direct tool calling first + test_direct_tools() + asyncio.run(main()) From ee12db36c34e786fef880d9699d6696d41ffa14c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Tue, 14 Oct 2025 08:41:45 -0400 Subject: [PATCH 19/23] feat(tool_executor): Plug tool executor into bidirectional streaming implementation --- .../bidirectional_streaming/__init__.py | 4 - .../models/__init__.py | 2 - .../models/novasonic.py | 1 + .../bidirectional_streaming/models/openai.py | 522 ------------------ ...al_streaming.py => test_bidi_novasonic.py} | 0 .../tests/test_bidi_openai.py | 317 ----------- .../types/bidirectional_streaming.py | 35 +- 7 files changed, 29 insertions(+), 852 deletions(-) delete mode 100644 src/strands/experimental/bidirectional_streaming/models/openai.py rename src/strands/experimental/bidirectional_streaming/tests/{test_bidirectional_streaming.py => test_bidi_novasonic.py} (100%) delete mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 844a8a1f8..0f842ee9f 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -8,7 +8,6 @@ # Model providers - What users need to create models from .models.novasonic import NovaSonicBidirectionalModel -from .models.openai import OpenAIRealtimeBidirectionalModel # Event types - For type hints and event handling from .types.bidirectional_streaming import ( @@ -18,7 +17,6 @@ InterruptionDetectedEvent, TextOutputEvent, UsageMetricsEvent, - VoiceActivityEvent, ) __all__ = [ @@ -26,14 +24,12 @@ "BidirectionalAgent", # Model providers "NovaSonicBidirectionalModel", - "OpenAIRealtimeBidirectionalModel", # Event types "AudioInputEvent", "AudioOutputEvent", "TextOutputEvent", "InterruptionDetectedEvent", "BidirectionalStreamEvent", - "VoiceActivityEvent", "UsageMetricsEvent", # Model interface "BidirectionalModel", diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 882f89eef..3a785e98a 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -8,6 +8,4 @@ "BidirectionalModelSession", "NovaSonicBidirectionalModel", "NovaSonicSession", - "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index a1d61e11a..7f35a3c1c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -35,6 +35,7 @@ BidirectionalConnectionStartEvent, InterruptionDetectedEvent, TextOutputEvent, + UsageMetricsEvent ) from .bidirectional_model import BidirectionalModel, BidirectionalModelSession diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py deleted file mode 100644 index 7c79e3e6c..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ /dev/null @@ -1,522 +0,0 @@ -/Users/mehtarac/Desktop/sdk-python/src/strands/experimental/bidirectional_streaming/models/openai.py - -"""OpenAI Realtime API provider for Strands bidirectional streaming. - -Provides real-time audio and text communication through OpenAI's Realtime API -with WebSocket connections, voice activity detection, and function calling. -""" - -import asyncio -import base64 -import json -import logging -import uuid -from typing import AsyncIterable - -import websockets -from websockets.client import WebSocketClientProtocol -from websockets.exceptions import ConnectionClosed - -from ....types.content import Messages -from ....types.tools import ToolSpec, ToolUse -from ..types.bidirectional_streaming import ( - AudioInputEvent, - AudioOutputEvent, - BidirectionalConnectionEndEvent, - BidirectionalConnectionStartEvent, - BidirectionalStreamEvent, - TextOutputEvent, - VoiceActivityEvent, -) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession - -logger = logging.getLogger(__name__) - -# OpenAI Realtime API configuration -OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" -DEFAULT_MODEL = "gpt-realtime" - -AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} - -DEFAULT_SESSION_CONFIG = { - "type": "realtime", - "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", - "output_modalities": ["audio"], - "audio": { - "input": { - "format": AUDIO_FORMAT, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500, - }, - }, - "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, - }, -} - - -class OpenAIRealtimeSession(BidirectionalModelSession): - """OpenAI Realtime API session for real-time audio/text streaming. - - Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, - function calling, and event conversion to Strands format. - """ - - def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: - """Initialize OpenAI Realtime session.""" - self.websocket = websocket - self.config = config - self.session_id = str(uuid.uuid4()) - self._active = True - - self._event_queue = asyncio.Queue() - self._response_task = None - self._function_call_buffer = {} - - logger.debug("OpenAI Realtime session initialized: %s", self.session_id) - - def _require_active(self) -> bool: - """Check if session is active.""" - return self._active - - def _create_text_event(self, text: str, role: str) -> dict[str, any]: - """Create standardized text output event.""" - text_output: TextOutputEvent = {"text": text, "role": role} - return {"textOutput": text_output} - - def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: - """Create standardized voice activity event.""" - voice_activity: VoiceActivityEvent = {"activityType": activity_type} - return {"voiceActivity": voice_activity} - - async def _create_conversation_item(self, item_data: dict) -> None: - """Create conversation item and trigger response.""" - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def initialize( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - ) -> None: - """Initialize session with configuration.""" - try: - session_config = self._build_session_config(system_prompt, tools) - await self._send_event({"type": "session.update", "session": session_config}) - - if messages: - await self._add_conversation_history(messages) - - self._response_task = asyncio.create_task(self._process_responses()) - logger.info("OpenAI Realtime session initialized successfully") - - except Exception as e: - logger.error("Error during OpenAI Realtime initialization: %s", e) - raise - - def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: - """Build session configuration for OpenAI Realtime API.""" - config = DEFAULT_SESSION_CONFIG.copy() - - if system_prompt: - config["instructions"] = system_prompt - - if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) - - custom_config = self.config.get("session", {}) - supported_params = { - "type", - "output_modalities", - "instructions", - "voice", - "audio", - "tools", - "tool_choice", - "input_audio_format", - "output_audio_format", - "input_audio_transcription", - "turn_detection", - } - - for key, value in custom_config.items(): - if key in supported_params: - config[key] = value - else: - logger.warning("Ignoring unsupported session parameter: %s", key) - - return config - - def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI Realtime API format.""" - openai_tools = [] - - for tool in tools: - input_schema = tool["inputSchema"] - if "json" in input_schema: - schema = ( - json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] - ) - else: - schema = input_schema - - # OpenAI Realtime API expects flat structure, not nested under "function" - openai_tool = { - "type": "function", - "name": tool["name"], - "description": tool["description"], - "parameters": schema, - } - openai_tools.append(openai_tool) - - return openai_tools - - async def _add_conversation_history(self, messages: Messages) -> None: - """Add conversation history to the session.""" - for message in messages: - conversation_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": message["role"], "content": []}, - } - - content = message.get("content", "") - if isinstance(content, str): - conversation_item["item"]["content"].append({"type": "input_text", "text": content}) - elif isinstance(content, list): - for item in content: - if isinstance(item, dict) and item.get("type") == "text": - conversation_item["item"]["content"].append( - {"type": "input_text", "text": item.get("text", "")} - ) - - await self._send_event(conversation_item) - - async def _process_responses(self) -> None: - """Process incoming WebSocket messages.""" - logger.debug("OpenAI Realtime response processor started") - - try: - async for message in self.websocket: - if not self._active: - break - - try: - event = json.loads(message) - await self._event_queue.put(event) - except json.JSONDecodeError as e: - logger.warning("Failed to parse OpenAI event: %s", e) - continue - - except ConnectionClosed: - logger.debug("OpenAI Realtime WebSocket connection closed") - except Exception as e: - logger.error("Error in OpenAI Realtime response processing: %s", e) - finally: - self._active = False - logger.debug("OpenAI Realtime response processor stopped") - - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive OpenAI events and convert to Strands format.""" - connection_start: BidirectionalConnectionStartEvent = { - "connectionId": self.session_id, - "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, - } - yield {"BidirectionalConnectionStart": connection_start} - - try: - while self._active: - try: - openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) - provider_event = self._convert_openai_event(openai_event) - if provider_event: - yield provider_event - except asyncio.TimeoutError: - continue - - except Exception as e: - logger.error("Error receiving OpenAI Realtime event: %s", e) - finally: - connection_end: BidirectionalConnectionEndEvent = { - "connectionId": self.session_id, - "reason": "connection_complete", - "metadata": {"provider": "openai_realtime"}, - } - yield {"BidirectionalConnectionEnd": connection_end} - - def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] | None: - """Convert OpenAI events to Strands format.""" - event_type = openai_event.get("type") - - # Audio output - if event_type == "response.output_audio.delta": - audio_data = base64.b64decode(openai_event["delta"]) - audio_output: AudioOutputEvent = { - "audioData": audio_data, - "format": "pcm", - "sampleRate": 24000, - "channels": 1, - "encoding": None, - } - return {"audioOutput": audio_output} - - # Text output using helper method - elif event_type == "response.output_text.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - elif event_type == "response.output_audio_transcript.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - # User transcription - elif event_type == "conversation.item.input_audio_transcription.delta": - transcript_delta = openai_event.get("delta", "") - return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.completed": - transcript = openai_event.get("transcript", "") - return self._create_text_event(transcript, "user") if transcript.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.segment": - segment_data = openai_event.get("segment", {}) - text = segment_data.get("text", "") - return self._create_text_event(text, "user") if text.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.failed": - error_info = openai_event.get("error", {}) - logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) - return None - - # Function call processing - elif event_type == "response.function_call_arguments.delta": - call_id = openai_event.get("call_id") - delta = openai_event.get("delta", "") - if call_id: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} - else: - self._function_call_buffer[call_id]["arguments"] += delta - return None - - elif event_type == "response.function_call_arguments.done": - call_id = openai_event.get("call_id") - if call_id and call_id in self._function_call_buffer: - function_call = self._function_call_buffer[call_id] - try: - tool_use: ToolUse = { - "toolUseId": call_id, - "name": function_call["name"], - "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, - } - del self._function_call_buffer[call_id] - return {"toolUse": tool_use} - except (json.JSONDecodeError, KeyError) as e: - logger.warning("Error parsing function arguments for %s: %s", call_id, e) - del self._function_call_buffer[call_id] - return None - - # Voice activity detection using helper method - elif event_type == "input_audio_buffer.speech_started": - return self._create_voice_activity_event("speech_started") - elif event_type == "input_audio_buffer.speech_stopped": - return self._create_voice_activity_event("speech_stopped") - elif event_type == "input_audio_buffer.timeout_triggered": - return self._create_voice_activity_event("timeout") - - # Lifecycle events (log only) - elif event_type == "conversation.item.retrieve": - item = openai_event.get("item", {}) - logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) - return None - - elif event_type == "conversation.item.added": - logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) - return None - - elif event_type == "conversation.item.done": - logger.debug("OpenAI conversation item done: %s", openai_event.get("item", {}).get("id")) - - item = openai_event.get("item", {}) - if item.get("type") == "message" and item.get("role") == "assistant": - content_parts = item.get("content", []) - if content_parts: - message_content = [] - for content_part in content_parts: - if content_part.get("type") == "output_text": - message_content.append({"type": "text", "text": content_part.get("text", "")}) - elif content_part.get("type") == "output_audio": - transcript = content_part.get("transcript", "") - if transcript: - message_content.append({"type": "text", "text": transcript}) - - if message_content: - message = {"role": "assistant", "content": message_content} - return {"messageStop": {"message": message}} - return None - - elif event_type in [ - "response.output_item.added", - "response.output_item.done", - "response.content_part.added", - "response.content_part.done", - ]: - item_data = openai_event.get("item") or openai_event.get("part") - logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") - - # Track function call names from response.output_item.added - if event_type == "response.output_item.added": - item = openai_event.get("item", {}) - if item.get("type") == "function_call": - call_id = item.get("call_id") - function_name = item.get("name") - if call_id and function_name: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = { - "call_id": call_id, - "name": function_name, - "arguments": "", - } - else: - self._function_call_buffer[call_id]["name"] = function_name - return None - - elif event_type in [ - "input_audio_buffer.committed", - "input_audio_buffer.cleared", - "session.created", - "session.updated", - ]: - logger.debug("OpenAI %s event", event_type) - return None - - elif event_type == "error": - logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) - return None - - else: - logger.debug("Unhandled OpenAI event type: %s", event_type) - return None - - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to OpenAI for processing.""" - if not self._require_active(): - return - - audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to OpenAI for processing.""" - if not self._require_active(): - return - - item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} - await self._create_conversation_item(item_data) - - async def send_interrupt(self) -> None: - """Send interruption signal to OpenAI.""" - if not self._require_active(): - return - - await self._send_event({"type": "response.cancel"}) - - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool result back to OpenAI.""" - if not self._require_active(): - return - - logger.debug("OpenAI tool result send: %s", tool_use_id) - result_text = json.dumps(result) if not isinstance(result, str) else result - - item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_text} - await self._create_conversation_item(item_data) - - async def close(self) -> None: - """Close session and cleanup resources.""" - if not self._active: - return - - logger.debug("OpenAI Realtime cleanup - starting connection close") - self._active = False - - if self._response_task and not self._response_task.done(): - self._response_task.cancel() - try: - await self._response_task - except asyncio.CancelledError: - pass - - try: - await self.websocket.close() - except Exception as e: - logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) - - logger.debug("OpenAI Realtime connection closed") - - async def _send_event(self, event: dict[str, any]) -> None: - """Send event to OpenAI via WebSocket.""" - try: - message = json.dumps(event) - await self.websocket.send(message) - logger.debug("Sent OpenAI event: %s", event.get("type")) - except Exception as e: - logger.error("Error sending OpenAI event: %s", e) - raise - - -class OpenAIRealtimeBidirectionalModel(BidirectionalModel): - """OpenAI Realtime API provider for Strands bidirectional streaming. - - Provides real-time audio/text communication through OpenAI's Realtime API - with WebSocket connections, voice activity detection, and function calling. - """ - - def __init__(self, model: str = DEFAULT_MODEL, api_key: str | None = None, **config: any) -> None: - """Initialize OpenAI Realtime bidirectional model.""" - self.model = model - self.api_key = api_key - self.config = config - - import os - - if not self.api_key: - self.api_key = os.getenv("OPENAI_API_KEY") - if not self.api_key: - raise ValueError( - "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter." - ) - - logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) - - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create bidirectional connection to OpenAI Realtime API.""" - logger.info("Creating OpenAI Realtime connection...") - - try: - url = f"{OPENAI_REALTIME_URL}?model={self.model}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if "organization" in self.config: - headers.append(("OpenAI-Organization", self.config["organization"])) - if "project" in self.config: - headers.append(("OpenAI-Project", self.config["project"])) - - websocket = await websockets.connect(url, additional_headers=headers) - logger.info("WebSocket connected successfully") - - session = OpenAIRealtimeSession(websocket, self.config) - await session.initialize(system_prompt, tools, messages) - - logger.info("OpenAI Realtime connection established") - return session - - except Exception as e: - logger.error("OpenAI connection error: %s", e) - raise diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py similarity index 100% rename from src/strands/experimental/bidirectional_streaming/tests/test_bidirectional_streaming.py rename to src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py deleted file mode 100644 index 5ce4b8cb2..000000000 --- a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -"""Test OpenAI Realtime API speech-to-speech interaction.""" - -import asyncio -import os -import sys -import time -from pathlib import Path - -# Add the src directory to Python path -sys.path.insert(0, str(Path(__file__).parent / "src")) - -import pyaudio -from strands_tools import calculator - -from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent -from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel - - -def test_direct_tool_calling(): - """Test direct tool calling functionality.""" - print("Testing direct tool calling...") - - try: - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY not set - skipping test") - return - - model = OpenAIRealtimeBidirectionalModel(model="gpt-4o-realtime-preview", api_key=api_key) - agent = BidirectionalAgent(model=model, tools=[calculator]) - - # Test calculator - result = agent.tool.calculator(expression="2 * 3") - content = result.get("content", [{}])[0].get("text", "") - print(f"Result: {content}") - print("Test completed") - - except Exception as e: - print(f"Test failed: {e}") - - -async def play(context): - """Handle audio playback with interruption support.""" - audio = pyaudio.PyAudio() - - try: - speaker = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # OpenAI Realtime uses 24kHz - output=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - # Check for interruption - if context.get("interrupted", False): - # Clear audio queue on interruption - while not context["audio_out"].empty(): - try: - context["audio_out"].get_nowait() - except asyncio.QueueEmpty: - break - - context["interrupted"] = False - await asyncio.sleep(0.05) - continue - - # Get audio data with timeout - try: - audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) - - if audio_data and context["active"]: - # Play in chunks to allow interruption - chunk_size = 1024 - for i in range(0, len(audio_data), chunk_size): - if context.get("interrupted", False) or not context["active"]: - break - - chunk = audio_data[i:i + chunk_size] - speaker.write(chunk) - await asyncio.sleep(0.001) # Brief pause for responsiveness - - except asyncio.TimeoutError: - continue - - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Audio playback error: {e}") - finally: - try: - speaker.close() - except Exception: - pass - audio.terminate() - - -async def record(context): - """Handle microphone recording.""" - audio = pyaudio.PyAudio() - - try: - microphone = audio.open( - format=pyaudio.paInt16, - channels=1, - rate=24000, # Match OpenAI's expected input rate - input=True, - frames_per_buffer=1024, - ) - - while context["active"]: - try: - audio_bytes = microphone.read(1024, exception_on_overflow=False) - await context["audio_in"].put(audio_bytes) - await asyncio.sleep(0.01) - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Microphone recording error: {e}") - finally: - try: - microphone.close() - except Exception: - pass - audio.terminate() - - -async def receive(agent, context): - """Handle events from the agent.""" - try: - async for event in agent.receive(): - if not context["active"]: - break - - # Handle audio output - if "audioOutput" in event: - audio_data = event["audioOutput"]["audioData"] - - if not context.get("interrupted", False): - await context["audio_out"].put(audio_data) - - # Handle text output (transcripts) - elif "textOutput" in event: - text_output = event["textOutput"] - role = text_output.get("role", "assistant") - text = text_output.get("text", "").strip() - - if text: - if role == "user": - print(f"User: {text}") - elif role == "assistant": - print(f"Assistant: {text}") - - # Handle interruption detection - elif "interruptionDetected" in event: - context["interrupted"] = True - - # Handle connection events - elif "BidirectionalConnectionStart" in event: - pass # Silent connection start - elif "BidirectionalConnectionEnd" in event: - context["active"] = False - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Receive handler error: {e}") - finally: - pass - - -async def send(agent, context): - """Send audio from microphone to agent.""" - try: - while context["active"]: - try: - audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) - - # Create audio event in expected format - audio_event = { - "audioData": audio_bytes, - "format": "pcm", - "sampleRate": 24000, - "channels": 1 - } - - await agent.send(audio_event) - - except asyncio.TimeoutError: - continue - except asyncio.CancelledError: - break - - except asyncio.CancelledError: - pass - except Exception as e: - print(f"Send handler error: {e}") - finally: - pass - - -async def main(): - """Main test function for OpenAI voice chat.""" - print("Starting OpenAI Realtime API test...") - - # Check API key - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - print("OPENAI_API_KEY environment variable not set") - return False - - # Check audio system - try: - audio = pyaudio.PyAudio() - audio.terminate() - except Exception as e: - print(f"Audio system error: {e}") - return False - - # Create OpenAI model - model = OpenAIRealtimeBidirectionalModel( - model="gpt-4o-realtime-preview", - api_key=api_key, - session={ - "output_modalities": ["audio"], - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": 24000}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "silence_duration_ms": 700 - } - }, - "output": { - "format": {"type": "audio/pcm", "rate": 24000}, - "voice": "alloy" - } - } - } - ) - - # Create agent - agent = BidirectionalAgent( - model=model, - tools=[calculator], - system_prompt=( - "You are a helpful voice assistant. Keep your responses brief and natural. " - "Say hello when you first connect." - ) - ) - - # Start the session - await agent.start() - - # Create shared context - context = { - "active": True, - "audio_in": asyncio.Queue(), - "audio_out": asyncio.Queue(), - "interrupted": False, - "start_time": time.time() - } - - print("Speak into your microphone. Press Ctrl+C to stop.") - - try: - # Run all tasks concurrently - await asyncio.gather( - play(context), - record(context), - receive(agent, context), - send(agent, context), - return_exceptions=True - ) - - except KeyboardInterrupt: - print("\nInterrupted by user") - except asyncio.CancelledError: - print("\nTest cancelled") - except Exception as e: - print(f"\nError during voice chat: {e}") - finally: - print("Cleaning up...") - context["active"] = False - - try: - await agent.end() - except Exception as e: - print(f"Cleanup error: {e}") - - return True - - -if __name__ == "__main__": - # Test direct tool calling first - print("OpenAI Realtime API Test Suite") - test_direct_tool_calling() - - try: - asyncio.run(main()) - except KeyboardInterrupt: - print("\nTest interrupted by user") - except Exception as e: - print(f"Test error: {e}") - import traceback - traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 01d72356a..c0f6eb209 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -116,10 +116,28 @@ class BidirectionalConnectionEndEvent(TypedDict): metadata: Provider-specific connection metadata. """ - reason: Literal["user_request", "timeout", "error"] + reason: Literal["user_request", "timeout", "error", "connection_complete"] connectionId: Optional[str] metadata: Optional[Dict[str, Any]] +class UsageMetricsEvent(TypedDict): + """Token usage and performance tracking. + + Provides standardized usage metrics across providers for cost monitoring + and performance optimization. + + Attributes: + totalTokens: Total tokens used in the interaction. + inputTokens: Tokens used for input processing. + outputTokens: Tokens used for output generation. + audioTokens: Tokens used specifically for audio processing. + """ + + totalTokens: Optional[int] + inputTokens: Optional[int] + outputTokens: Optional[int] + audioTokens: Optional[int] + class BidirectionalStreamEvent(StreamEvent, total=False): """Bidirectional stream event extending existing StreamEvent. @@ -134,11 +152,14 @@ class BidirectionalStreamEvent(StreamEvent, total=False): interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. + usageMetrics: Token usage and performance metrics. """ - audioOutput: AudioOutputEvent - audioInput: AudioInputEvent - textOutput: TextOutputEvent - interruptionDetected: InterruptionDetectedEvent - BidirectionalConnectionStart: BidirectionalConnectionStartEvent - BidirectionalConnectionEnd: BidirectionalConnectionEndEvent + audioOutput: Optional[AudioOutputEvent] + audioInput: Optional[AudioInputEvent] + textOutput: Optional[TextOutputEvent] + interruptionDetected: Optional[InterruptionDetectedEvent] + BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] + BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] + usageMetrics: Optional[UsageMetricsEvent] + From 4679e0c803d0e7b6fad7d32ef0866309fd8b55e4 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Mon, 20 Oct 2025 10:03:11 -0400 Subject: [PATCH 20/23] (feat)bidirectional_streaming: add openai realtime model provider #3 --- .../models/novasonic.py | 14 ++-- .../bidirectional_streaming/models/openai.py | 64 +++++++++---------- 2 files changed, 37 insertions(+), 41 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 4e4952fa9..db21fb967 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -23,7 +23,7 @@ from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, InvokeModelWithBidirectionalStreamOperationOutput from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver from ....types.content import Messages @@ -80,11 +80,11 @@ class NovaSonicSession(BidirectionalModelSession): interface. """ - def __init__(self, stream: any, config: dict[str, any]) -> None: + def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, config: dict[str, any]) -> None: """Initialize Nova Sonic connection. Args: - stream: Nova Sonic bidirectional stream. + stream: Nova Sonic bidirectional stream operation output from AWS SDK. config: Model configuration. """ self.stream = stream @@ -492,10 +492,10 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | No elif "usageEvent" in nova_event: usage_data = nova_event["usageEvent"] usage_metrics: UsageMetricsEvent = { - "totalTokens": usage_data.get("totalTokens"), - "inputTokens": usage_data.get("totalInputTokens"), - "outputTokens": usage_data.get("totalOutputTokens"), - "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens") + "totalTokens": usage_data.get("totalTokens", 0), + "inputTokens": usage_data.get("totalInputTokens", 0), + "outputTokens": usage_data.get("totalOutputTokens", 0), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens", 0) } return {"usageMetrics": usage_metrics} diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 76bf9f50d..7d009b1c7 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -89,10 +89,7 @@ def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: voice_activity: VoiceActivityEvent = {"activityType": activity_type} return {"voiceActivity": voice_activity} - async def _create_conversation_item(self, item_data: dict) -> None: - """Create conversation item and trigger response.""" - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) + async def initialize( self, @@ -248,21 +245,16 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] } return {"audioOutput": audio_output} - # Text output using helper method - elif event_type == "response.output_text.delta": + # Assistant text output events - combine multiple similar events + elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: return self._create_text_event(openai_event["delta"], "assistant") - elif event_type == "response.output_audio_transcript.delta": - return self._create_text_event(openai_event["delta"], "assistant") - - # User transcription - elif event_type == "conversation.item.input_audio_transcription.delta": - transcript_delta = openai_event.get("delta", "") - return self._create_text_event(transcript_delta, "user") if transcript_delta.strip() else None - - elif event_type == "conversation.item.input_audio_transcription.completed": - transcript = openai_event.get("transcript", "") - return self._create_text_event(transcript, "user") if transcript.strip() else None + # User transcription events - combine multiple similar events + elif event_type in ["conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed"]: + text_key = "delta" if "delta" in event_type else "transcript" + text = openai_event.get(text_key, "") + return self._create_text_event(text, "user") if text.strip() else None elif event_type == "conversation.item.input_audio_transcription.segment": segment_data = openai_event.get("segment", {}) @@ -302,22 +294,22 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] del self._function_call_buffer[call_id] return None - # Voice activity detection using helper method - elif event_type == "input_audio_buffer.speech_started": - return self._create_voice_activity_event("speech_started") - elif event_type == "input_audio_buffer.speech_stopped": - return self._create_voice_activity_event("speech_stopped") - elif event_type == "input_audio_buffer.timeout_triggered": - return self._create_voice_activity_event("timeout") + # Voice activity detection events - combine similar events using mapping + elif event_type in ["input_audio_buffer.speech_started", "input_audio_buffer.speech_stopped", + "input_audio_buffer.timeout_triggered"]: + # Map event types to activity types + activity_map = { + "input_audio_buffer.speech_started": "speech_started", + "input_audio_buffer.speech_stopped": "speech_stopped", + "input_audio_buffer.timeout_triggered": "timeout" + } + return self._create_voice_activity_event(activity_map[event_type]) - # Lifecycle events (log only) - elif event_type == "conversation.item.retrieve": + # Lifecycle events (log only) - combine multiple similar events + elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: item = openai_event.get("item", {}) - logger.debug("OpenAI conversation item retrieved: %s", item.get("id")) - return None - - elif event_type == "conversation.item.added": - logger.debug("OpenAI conversation item added: %s", openai_event.get("item", {}).get("id")) + action = "retrieved" if "retrieve" in event_type else "added" + logger.debug("OpenAI conversation item %s: %s", action, item.get("id")) return None elif event_type == "conversation.item.done": @@ -341,6 +333,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] return {"messageStop": {"message": message}} return None + # Response output events - combine similar events elif event_type in ["response.output_item.added", "response.output_item.done", "response.content_part.added", "response.content_part.done"]: item_data = openai_event.get("item") or openai_event.get("part") @@ -359,6 +352,7 @@ def _convert_openai_event(self, openai_event: dict[str, any]) -> dict[str, any] self._function_call_buffer[call_id]["name"] = function_name return None + # Session/buffer events - combine simple log-only events elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", "session.created", "session.updated"]: logger.debug("OpenAI %s event", event_type) @@ -380,7 +374,7 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) - async def send_text_content(self, text: str, **kwargs) -> None: + async def send_text_content(self, text: str) -> None: """Send text content to OpenAI for processing.""" if not self._require_active(): return @@ -390,7 +384,8 @@ async def send_text_content(self, text: str, **kwargs) -> None: "role": "user", "content": [{"type": "input_text", "text": text}] } - await self._create_conversation_item(item_data) + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) async def send_interrupt(self) -> None: """Send interruption signal to OpenAI.""" @@ -412,7 +407,8 @@ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> No "call_id": tool_use_id, "output": result_text } - await self._create_conversation_item(item_data) + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) async def close(self) -> None: """Close session and cleanup resources.""" From 4648327c2b4dce49b3f25561012bdf76ab7cb72d Mon Sep 17 00:00:00 2001 From: Murat Kaan Meral Date: Wed, 29 Oct 2025 15:52:11 +0100 Subject: [PATCH 21/23] feat(gemini): Add bidirectional gemini model --- .../bidirectional_streaming/agent/agent.py | 30 +- .../models/__init__.py | 11 +- .../models/bidirectional_model.py | 11 +- .../models/gemini_live.py | 499 ++++++++++++++++++ .../tests/test_gemini_live.py | 359 +++++++++++++ .../bidirectional_streaming/types/__init__.py | 4 + .../types/bidirectional_streaming.py | 38 ++ 7 files changed, 933 insertions(+), 19 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/gemini_live.py create mode 100644 src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 6f8360ade..820a6c490 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -31,7 +31,7 @@ from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection from ..models.bidirectional_model import BidirectionalModel -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent logger = logging.getLogger(__name__) @@ -359,18 +359,16 @@ async def start(self) -> None: logger.debug("Conversation start - initializing session") self._session = await start_bidirectional_connection(self) - logger.debug("Conversation ready") - - async def send(self, input_data: str | AudioInputEvent) -> None: - """Send input to the model (text or audio). - - Unified method for sending both text and audio input to the model during - an active conversation session. User input is automatically added to - conversation history for complete message tracking. - + + async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> None: + """Send input to the model (text, audio, or image). + + Unified method for sending text, audio, and image input to the model during + an active conversation session. + Args: - input_data: Either a string for text input or AudioInputEvent for audio input. - + input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images. + Raises: ValueError: If no active session or invalid input type. """ @@ -385,10 +383,14 @@ async def send(self, input_data: str | AudioInputEvent) -> None: elif isinstance(input_data, dict) and "audioData" in input_data: # Handle audio input await self._session.model_session.send_audio_content(input_data) + elif isinstance(input_data, dict) and "imageData" in input_data: + # Handle image input (ImageInputEvent) + await self._session.model_session.send_image_content(input_data) else: raise ValueError( - "Input must be either a string (text) or AudioInputEvent " - "(dict with audioData, format, sampleRate, channels)" + "Input must be either a string (text), AudioInputEvent " + "(dict with audioData, format, sampleRate, channels), or ImageInputEvent " + "(dict with imageData, mimeType, encoding)" ) async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index 67254d4fe..c5287d15d 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,14 +1,17 @@ """Bidirectional model interfaces and implementations.""" from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .gemini_live import GeminiLiveBidirectionalModel, GeminiLiveSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession __all__ = [ - "BidirectionalModel", - "BidirectionalModelSession", - "NovaSonicBidirectionalModel", + "BidirectionalModel", + "BidirectionalModelSession", + "GeminiLiveBidirectionalModel", + "GeminiLiveSession", + "NovaSonicBidirectionalModel", "NovaSonicSession", "OpenAIRealtimeBidirectionalModel", - "OpenAIRealtimeSession" + "OpenAIRealtimeSession", ] diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py index d5c3c9b65..42485561b 100644 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py @@ -17,7 +17,7 @@ from ....types.content import Messages from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent logger = logging.getLogger(__name__) @@ -48,6 +48,15 @@ async def send_audio_content(self, audio_input: AudioInputEvent) -> None: """ raise NotImplementedError + # TODO: remove with interface unification + async def send_image_content(self, image_input: ImageInputEvent) -> None: + """Send image content to the model during an active connection. + + Handles image encoding and provider-specific formatting while presenting + a simple ImageInputEvent interface. + """ + raise NotImplementedError + @abc.abstractmethod async def send_text_content(self, text: str, **kwargs) -> None: """Send text content to the model during ongoing generation. diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py new file mode 100644 index 000000000..64c4d7348 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -0,0 +1,499 @@ +"""Gemini Live API bidirectional model provider using official Google GenAI SDK. + +Implements the BidirectionalModel interface for Google's Gemini Live API using the +official Google GenAI SDK for simplified and robust WebSocket communication. + +Key improvements over custom WebSocket implementation: +- Uses official google-genai SDK with native Live API support +- Simplified session management with client.aio.live.connect() +- Built-in tool integration and event handling +- Automatic WebSocket connection management and error handling +- Native support for audio/text streaming and interruption +""" + +import asyncio +import base64 +import logging +import uuid +from typing import Any, AsyncIterable, Dict, List, Optional + +from google import genai +from google.genai import types as genai_types +from google.genai.types import LiveServerMessage, LiveServerContent + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + ImageInputEvent, + InterruptionDetectedEvent, + TextOutputEvent, + TranscriptEvent, +) +from .bidirectional_model import BidirectionalModel, BidirectionalModelSession + +logger = logging.getLogger(__name__) + +# Audio format constants +GEMINI_INPUT_SAMPLE_RATE = 16000 +GEMINI_OUTPUT_SAMPLE_RATE = 24000 +GEMINI_CHANNELS = 1 + + +class GeminiLiveSession(BidirectionalModelSession): + """Gemini Live API session using official Google GenAI SDK. + + Provides a clean interface to Gemini Live API using the official SDK, + eliminating custom WebSocket handling and providing robust error handling. + """ + + def __init__(self, client: genai.Client, model_id: str, config: Dict[str, Any]): + """Initialize Gemini Live API session. + + Args: + client: Gemini client instance + model_id: Model identifier + config: Model configuration including live config + """ + self.client = client + self.model_id = model_id + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + self.live_session = None + self.live_session_cm = None + + + + async def initialize( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None + ) -> None: + """Initialize Gemini Live API session by creating the connection.""" + + try: + # Build live config + live_config = self.config.get("live_config") + + if live_config is None: + raise ValueError("live_config is required but not found in session config") + + # Create the context manager + self.live_session_cm = self.client.aio.live.connect( + model=self.model_id, + config=live_config + ) + + # Enter the context manager + self.live_session = await self.live_session_cm.__aenter__() + + # Send initial message history if provided + if messages: + await self._send_message_history(messages) + + + except Exception as e: + logger.error("Error initializing Gemini Live session: %s", e) + raise + + async def _send_message_history(self, messages: Messages) -> None: + """Send conversation history to Gemini Live API. + + Sends each message as a separate turn with the correct role to maintain + proper conversation context. Follows the same pattern as the non-bidirectional + Gemini model implementation. + """ + if not messages: + return + + # Convert each message to Gemini format and send separately + for message in messages: + content_parts = [] + for content_block in message["content"]: + if "text" in content_block: + content_parts.append(genai_types.Part(text=content_block["text"])) + + if content_parts: + # Map role correctly - Gemini uses "user" and "model" roles + # "assistant" role from Messages format maps to "model" in Gemini + role = "model" if message["role"] == "assistant" else message["role"] + content = genai_types.Content(role=role, parts=content_parts) + await self.live_session.send_client_content(turns=content) + + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive Gemini Live API events and convert to provider-agnostic format.""" + + # Emit connection start event + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "gemini_live", "model_id": self.config.get("model_id")} + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + # Wrap in while loop to restart after turn_complete (SDK limitation workaround) + while self._active: + try: + async for message in self.live_session.receive(): + if not self._active: + break + + # Convert to provider-agnostic format + provider_event = self._convert_gemini_live_event(message) + if provider_event: + yield provider_event + + # SDK exits receive loop after turn_complete - restart automatically + if self._active: + logger.debug("Restarting receive loop after turn completion") + + except Exception as e: + logger.error("Error in receive iteration: %s", e) + # Small delay before retrying to avoid tight error loops + await asyncio.sleep(0.1) + + except Exception as e: + logger.error("Fatal error in receive loop: %s", e) + finally: + # Emit connection end event when exiting + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "gemini_live"} + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: + """Convert Gemini Live API events to provider-agnostic format. + + Handles different types of text output: + - inputTranscription: User's speech transcribed to text (emitted as transcript event) + - outputTranscription: Model's audio transcribed to text (emitted as transcript event) + - modelTurn text: Actual text response from the model (emitted as textOutput) + """ + try: + # Handle interruption first (from server_content) + if message.server_content and message.server_content.interrupted: + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + return {"interruptionDetected": interruption} + + # Handle input transcription (user's speech) - emit as transcript event + if message.server_content and message.server_content.input_transcription: + input_transcript = message.server_content.input_transcription + # Check if the transcription object has text content + if hasattr(input_transcript, 'text') and input_transcript.text: + transcription_text = input_transcript.text + logger.debug(f"Input transcription detected: {transcription_text}") + transcript: TranscriptEvent = { + "text": transcription_text, + "role": "user", + "type": "input" + } + return {"transcript": transcript} + + # Handle output transcription (model's audio) - emit as transcript event + if message.server_content and message.server_content.output_transcription: + output_transcript = message.server_content.output_transcription + # Check if the transcription object has text content + if hasattr(output_transcript, 'text') and output_transcript.text: + transcription_text = output_transcript.text + logger.debug(f"Output transcription detected: {transcription_text}") + transcript: TranscriptEvent = { + "text": transcription_text, + "role": "assistant", + "type": "output" + } + return {"transcript": transcript} + + # Handle actual text output from model (not transcription) + # The SDK's message.text property accesses modelTurn.parts[].text + if message.text: + text_output: TextOutputEvent = { + "text": message.text, + "role": "assistant" + } + return {"textOutput": text_output} + + # Handle audio output using SDK's built-in data property + if message.data: + audio_output: AudioOutputEvent = { + "audioData": message.data, + "format": "pcm", + "sampleRate": GEMINI_OUTPUT_SAMPLE_RATE, + "channels": GEMINI_CHANNELS, + "encoding": "raw" + } + return {"audioOutput": audio_output} + + # Handle tool calls + if message.tool_call and message.tool_call.function_calls: + for func_call in message.tool_call.function_calls: + tool_use_event: ToolUse = { + "toolUseId": func_call.id, + "name": func_call.name, + "input": func_call.args or {} + } + return {"toolUse": tool_use_event} + + # Silently ignore setup_complete, turn_complete, generation_complete, and usage_metadata messages + return None + + except Exception as e: + logger.error("Error converting Gemini Live event: %s", e) + logger.error("Message type: %s", type(message).__name__) + logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content using Gemini Live API. + + Gemini Live expects continuous audio streaming via send_realtime_input. + This automatically triggers VAD and can interrupt ongoing responses. + """ + if not self._active: + return + + try: + # Create audio blob for the SDK + audio_blob = genai_types.Blob( + data=audio_input["audioData"], + mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" + ) + + # Send real-time audio input - this automatically handles VAD and interruption + await self.live_session.send_realtime_input(audio=audio_blob) + + except Exception as e: + logger.error("Error sending audio content: %s", e) + + async def send_image_content(self, image_input: ImageInputEvent) -> None: + """Send image content using Gemini Live API. + + Sends image frames following the same pattern as the GitHub example. + Images are sent as base64-encoded data with MIME type. + """ + if not self._active: + return + + try: + # Prepare the message based on encoding + if image_input["encoding"] == "base64": + # Data is already base64 encoded + if isinstance(image_input["imageData"], bytes): + data_str = image_input["imageData"].decode() + else: + data_str = image_input["imageData"] + else: + # Raw bytes - need to base64 encode + data_str = base64.b64encode(image_input["imageData"]).decode() + + # Create the message in the format expected by Gemini Live + msg = { + "mime_type": image_input["mimeType"], + "data": data_str + } + + # Send using the same method as the GitHub example + await self.live_session.send(input=msg) + + except Exception as e: + logger.error("Error sending image content: %s", e) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content using Gemini Live API.""" + if not self._active: + return + + try: + # Create content with text + content = genai_types.Content( + role="user", + parts=[genai_types.Part(text=text)] + ) + + # Send as client content + await self.live_session.send_client_content(turns=content) + + except Exception as e: + logger.error("Error sending text content: %s", e) + + async def send_interrupt(self) -> None: + """Send interruption signal to Gemini Live API. + + Gemini Live uses automatic VAD-based interruption. When new audio input + is detected, it automatically interrupts the ongoing generation. + We don't need to send explicit interrupt signals like Nova Sonic. + """ + if not self._active: + return + + try: + # Gemini Live handles interruption automatically through VAD + # When new audio input is sent via send_realtime_input, it automatically + # interrupts any ongoing generation. No explicit interrupt signal needed. + logger.debug("Interrupt requested - Gemini Live handles this automatically via VAD") + + except Exception as e: + logger.error("Error in interrupt handling: %s", e) + + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool result using Gemini Live API.""" + if not self._active: + return + + try: + # Create function response + func_response = genai_types.FunctionResponse( + id=tool_use_id, + name=tool_use_id, # Gemini uses name as identifier + response=result + ) + + # Send tool response + await self.live_session.send_tool_response(function_responses=[func_response]) + except Exception as e: + logger.error("Error sending tool result: %s", e) + + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool error using Gemini Live API.""" + error_result = {"error": error} + await self.send_tool_result(tool_use_id, error_result) + + async def close(self) -> None: + """Close Gemini Live API connection.""" + if not self._active: + return + + self._active = False + + try: + # Exit the context manager properly + if self.live_session_cm: + await self.live_session_cm.__aexit__(None, None, None) + except Exception as e: + logger.error("Error closing Gemini Live session: %s", e) + raise + + +class GeminiLiveBidirectionalModel(BidirectionalModel): + """Gemini Live API model implementation using official Google GenAI SDK. + + Provides access to Google's Gemini Live API through the bidirectional + streaming interface, using the official SDK for robust and simple integration. + """ + + def __init__( + self, + model_id: str = "models/gemini-2.0-flash-live-preview-04-09", + api_key: Optional[str] = None, + **config + ): + """Initialize Gemini Live API bidirectional model. + + Args: + model_id: Gemini Live model identifier. + api_key: Google AI API key for authentication. + **config: Additional configuration. + """ + self.model_id = model_id + self.api_key = api_key + self.config = config + + # Create Gemini client with proper API version + client_kwargs = {} + if api_key: + client_kwargs["api_key"] = api_key + + # Use v1alpha for Live API as it has better model support + client_kwargs["http_options"] = {"api_version": "v1alpha"} + + self.client = genai.Client(**client_kwargs) + + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create Gemini Live API bidirectional connection using official SDK.""" + + try: + # Build configuration + live_config = self._build_live_config(system_prompt, tools, **kwargs) + + # Create session config + session_config = self._get_session_config() + session_config["live_config"] = live_config + + # Create and initialize session wrapper + session = GeminiLiveSession(self.client, self.model_id, session_config) + await session.initialize(system_prompt, tools, messages) + + return session + + except Exception as e: + logger.error("Failed to create Gemini Live connection: %s", e) + raise + + def _build_live_config( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + **kwargs + ) -> Dict[str, Any]: + """Build LiveConnectConfig for the official SDK. + + Simply passes through all config parameters from params, allowing users + to configure any Gemini Live API parameter directly. + """ + # Start with user config from params + config_dict = {} + if "params" in self.config: + config_dict.update(self.config["params"]) + + # Override with any kwargs + config_dict.update(kwargs) + + # Add system instruction if provided + if system_prompt: + config_dict["system_instruction"] = system_prompt + + # Add tools if provided + if tools: + config_dict["tools"] = self._format_tools_for_live_api(tools) + + return config_dict + + def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: + """Format tool specs for Gemini Live API.""" + if not tool_specs: + return [] + + return [ + genai_types.Tool( + function_declarations=[ + genai_types.FunctionDeclaration( + description=tool_spec["description"], + name=tool_spec["name"], + parameters_json_schema=tool_spec["inputSchema"]["json"], + ) + for tool_spec in tool_specs + ], + ), + ] + + def _get_session_config(self) -> Dict[str, Any]: + """Get session configuration for Gemini Live API.""" + return { + "model_id": self.model_id, + "params": self.config.get("params"), + **self.config + } \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py new file mode 100644 index 000000000..4469e819a --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -0,0 +1,359 @@ +"""Test suite for Gemini Live bidirectional streaming with camera support. + +Tests the Gemini Live API with real-time audio and video interaction including: +- Audio input/output streaming +- Camera frame capture and transmission +- Interruption handling +- Concurrent tool execution +- Transcript events + +Requirements: +- pip install opencv-python pillow pyaudio google-genai +- Camera access permissions +- GOOGLE_AI_API_KEY environment variable +""" + +import asyncio +import base64 +import io +import logging +import os +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import time + +try: + import cv2 + import PIL.Image + CAMERA_AVAILABLE = True +except ImportError as e: + print(f"Camera dependencies not available: {e}") + print("Install with: pip install opencv-python pillow") + CAMERA_AVAILABLE = False + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel + +# Configure logging - debug only for Gemini Live, info for everything else +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') +gemini_logger.setLevel(logging.DEBUG) +logger = logging.getLogger(__name__) + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + + # List all available audio devices + print("Available audio devices:") + for i in range(audio.get_device_count()): + device_info = audio.get_device_info_by_index(i) + if device_info['maxInputChannels'] > 0: # Only show input devices + print(f" Device {i}: {device_info['name']} (inputs: {device_info['maxInputChannels']})") + + # Get default input device info + default_device = audio.get_default_input_device_info() + print(f"\nUsing default input device: {default_device['name']} (Device {default_device['index']})") + + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + # Debug: Log all event types + event_types = [k for k in event.keys() if not k.startswith('_')] + if event_types: + logger.debug(f"Received event types: {event_types}") + + # Handle audio output + if "audioOutput" in event: + if not context.get("interrupted", False): + context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + + # Handle interruption events + elif "interruptionDetected" in event: + context["interrupted"] = True + elif "interrupted" in event: + context["interrupted"] = True + + # Handle text output + elif "textOutput" in event: + text_content = event["textOutput"].get("text", "") + role = event["textOutput"].get("role", "unknown") + + # Check for text-based interruption patterns + if '{ "interrupted" : true }' in text_content: + context["interrupted"] = True + elif "interrupted" in text_content.lower(): + context["interrupted"] = True + + # Log text output + if role.upper() == "USER": + print(f"User: {text_content}") + elif role.upper() == "ASSISTANT": + print(f"Assistant: {text_content}") + + # Handle transcript events (audio transcriptions) + elif "transcript" in event: + transcript_text = event["transcript"].get("text", "") + transcript_role = event["transcript"].get("role", "unknown") + transcript_type = event["transcript"].get("type", "unknown") + + # Print transcripts with special formatting to distinguish from text output + if transcript_role.upper() == "USER": + print(f"🎤 User (transcript): {transcript_text}") + elif transcript_role.upper() == "ASSISTANT": + print(f"🔊 Assistant (transcript): {transcript_text}") + + # Handle turn complete events + elif "turnComplete" in event: + logger.debug("Turn complete event received - model ready for next input") + # Reset interrupted state since the turn is complete + context["interrupted"] = False + + except asyncio.CancelledError: + pass + + +def _get_frame(cap): + """Capture and process a frame from camera.""" + if not CAMERA_AVAILABLE: + return None + + # Read the frame + ret, frame = cap.read() + # Check if the frame was read successfully + if not ret: + return None + # Convert BGR to RGB color space + # OpenCV captures in BGR but PIL expects RGB format + # This prevents the blue tint in the video feed + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = PIL.Image.fromarray(frame_rgb) + img.thumbnail([1024, 1024]) + + image_io = io.BytesIO() + img.save(image_io, format="jpeg") + image_io.seek(0) + + mime_type = "image/jpeg" + image_bytes = image_io.read() + return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} + + +async def get_frames(context): + """Capture frames from camera and send to agent.""" + if not CAMERA_AVAILABLE: + print("Camera not available - skipping video capture") + return + + # This takes about a second, and will block the whole program + # causing the audio pipeline to overflow if you don't to_thread it. + cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera + + print("Camera initialized. Starting video capture...") + + try: + while context["active"] and time.time() - context["start_time"] < context["duration"]: + frame = await asyncio.to_thread(_get_frame, cap) + if frame is None: + break + + # Send frame to agent as image input + try: + image_event = { + "imageData": frame["data"], + "mimeType": frame["mime_type"], + "encoding": "base64" + } + await context["agent"].send(image_event) + print("📸 Frame sent to model") + except Exception as e: + logger.error(f"Error sending frame: {e}") + + # Wait 1 second between frames (1 FPS) + await asyncio.sleep(1.0) + + except asyncio.CancelledError: + pass + finally: + # Release the VideoCapture object + cap.release() + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + await agent.send(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for Gemini Live bidirectional streaming test with camera support.""" + print("Starting Gemini Live bidirectional streaming test with camera...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + print("Video: Camera frames sent at 1 FPS to model") + + # Get API key from environment variable + api_key = os.getenv("GOOGLE_AI_API_KEY") + + if not api_key: + print("ERROR: GOOGLE_AI_API_KEY environment variable not set") + print("Please set it with: export GOOGLE_AI_API_KEY=your_api_key") + return + + # Initialize Gemini Live model with proper configuration + logger.info("Initializing Gemini Live model with API key") + + model = GeminiLiveBidirectionalModel( + model_id="gemini-2.5-flash-native-audio-preview-09-2025", + api_key=api_key, + params={ + "response_modalities": ["AUDIO"], + "output_audio_transcription": {}, # Enable output transcription + "input_audio_transcription": {} # Enable input transcription + } + ) + logger.info("Gemini Live model initialized successfully") + print("Using Gemini Live model") + + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant." + ) + + await agent.start() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "connection": agent._session, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + "agent": agent, # Add agent reference for camera task + } + + print("Speak into microphone and show things to camera. Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently including camera + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + get_frames(context), # Add camera task + return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.end() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py index 412061146..d040ee436 100644 --- a/src/strands/experimental/bidirectional_streaming/types/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -11,8 +11,10 @@ BidirectionalConnectionEndEvent, BidirectionalConnectionStartEvent, BidirectionalStreamEvent, + ImageInputEvent, InterruptionDetectedEvent, TextOutputEvent, + TranscriptEvent, UsageMetricsEvent, VoiceActivityEvent, ) @@ -23,8 +25,10 @@ "BidirectionalConnectionEndEvent", "BidirectionalConnectionStartEvent", "BidirectionalStreamEvent", + "ImageInputEvent", "InterruptionDetectedEvent", "TextOutputEvent", + "TranscriptEvent", "UsageMetricsEvent", "VoiceActivityEvent", "SUPPORTED_AUDIO_FORMATS", diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 4aa720b20..4b215d74e 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -71,6 +71,23 @@ class AudioInputEvent(TypedDict): channels: Literal[1, 2] +class ImageInputEvent(TypedDict): + """Image input event for sending images/video frames to the model. + + Used for sending image data through the send() method. Supports both + raw image bytes and base64-encoded data. + + Attributes: + imageData: Image bytes (raw or base64-encoded string). + mimeType: MIME type (e.g., "image/jpeg", "image/png"). + encoding: How the imageData is encoded. + """ + + imageData: bytes | str + mimeType: str + encoding: Literal["base64", "raw"] + + class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. @@ -83,6 +100,23 @@ class TextOutputEvent(TypedDict): role: Role +class TranscriptEvent(TypedDict): + """Transcript event for audio transcriptions. + + Used for both input transcriptions (user speech) and output transcriptions + (model audio). These are informational and separate from actual text responses. + + Attributes: + text: The transcribed text. + role: The role of the speaker ("user" or "assistant"). + type: Type of transcription ("input" or "output"). + """ + + text: str + role: Role + type: Literal["input", "output"] + + class InterruptionDetectedEvent(TypedDict): """Interruption detection event. @@ -180,7 +214,9 @@ class BidirectionalStreamEvent(StreamEvent, total=False): Attributes: audioOutput: Audio output from the model. audioInput: Audio input sent to the model. + imageInput: Image input sent to the model. textOutput: Text output from the model. + transcript: Audio transcription (input or output). interruptionDetected: User interruption detection. BidirectionalConnectionStart: connection start event. BidirectionalConnectionEnd: connection end event. @@ -190,7 +226,9 @@ class BidirectionalStreamEvent(StreamEvent, total=False): audioOutput: Optional[AudioOutputEvent] audioInput: Optional[AudioInputEvent] + imageInput: Optional[ImageInputEvent] textOutput: Optional[TextOutputEvent] + transcript: Optional[TranscriptEvent] interruptionDetected: Optional[InterruptionDetectedEvent] BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent] BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent] From 0fc511025dc13dca32a1d86c5e6abe8e348836a7 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 29 Oct 2025 16:48:21 -0400 Subject: [PATCH 22/23] feat: model interface change. --- .../bidirectional_streaming/__init__.py | 2 +- .../bidirectional_streaming/agent/agent.py | 2 +- .../event_loop/bidirectional_event_loop.py | 2 +- .../models/__init__.py | 2 +- .../models/base_model.py | 53 ++++++++ .../models/bidirectional_model.py | 113 ------------------ .../models/gemini_live.py | 2 +- .../models/novasonic.py | 2 +- .../bidirectional_streaming/models/openai.py | 2 +- .../types/bidirectional_streaming.py | 10 ++ 10 files changed, 70 insertions(+), 120 deletions(-) create mode 100644 src/strands/experimental/bidirectional_streaming/models/base_model.py delete mode 100644 src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py index 3c47dd957..695393dd8 100644 --- a/src/strands/experimental/bidirectional_streaming/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import BidirectionalAgent # Advanced interfaces (for custom implementations) -from .models.bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .models.base_model import BidirectionalModel, BidirectionalModelSession # Model providers - What users need to create models from .models.novasonic import NovaSonicBidirectionalModel diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py index 820a6c490..d0f9807bd 100644 --- a/src/strands/experimental/bidirectional_streaming/agent/agent.py +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -30,7 +30,7 @@ from ....types.tools import ToolResult, ToolUse from ....types.traces import AttributeValue from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection -from ..models.bidirectional_model import BidirectionalModel +from ..models.base_model import BaseModel from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py index bbf5fb425..f82b5911d 100644 --- a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -21,7 +21,7 @@ from ....types._events import ToolResultEvent, ToolStreamEvent from ....types.content import Message from ....types.tools import ToolResult, ToolUse -from ..models.bidirectional_model import BidirectionalModelSession +from ..models.base_model import BaseModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py index c5287d15d..16e2d18c9 100644 --- a/src/strands/experimental/bidirectional_streaming/models/__init__.py +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -1,6 +1,6 @@ """Bidirectional model interfaces and implementations.""" -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .base_model import BidirectionalModel, BidirectionalModelSession from .gemini_live import GeminiLiveBidirectionalModel, GeminiLiveSession from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession from .openai import OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession diff --git a/src/strands/experimental/bidirectional_streaming/models/base_model.py b/src/strands/experimental/bidirectional_streaming/models/base_model.py new file mode 100644 index 000000000..19e3747f4 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/base_model.py @@ -0,0 +1,53 @@ +"""Bidirectional model interface for real-time streaming conversations. + +Defines the interface for models that support bidirectional streaming capabilities. +Provides abstractions for different model providers with connection-based communication +patterns that support real-time audio and text interaction. + +Features: +- connection-based persistent connections +- Real-time bidirectional communication +- Provider-agnostic event normalization +- Tool execution integration +""" + +from typing import AsyncIterable, Protocol, Union + +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec +from ..types.bidirectional_streaming import ( + AudioInputEvent, + BidirectionalStreamEvent, + ImageInputEvent, + TextInputEvent, +) + + +class BaseModel(Protocol): + """Unified interface for bidirectional streaming models. + + Combines model configuration and session communication in a single abstraction. + Providers implement this directly without separate model/session classes. + """ + + async def connect( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> None: + """Establish bidirectional connection with the model.""" + ... + + async def close(self) -> None: + """Close connection and cleanup resources.""" + ... + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive events from the model in standardized format.""" + ... + + async def send(self, content: Union[TextInputEvent, ImageInputEvent, AudioInputEvent, ToolResult]) -> None: + """Send structured content to the model.""" + ... \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py b/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py deleted file mode 100644 index 42485561b..000000000 --- a/src/strands/experimental/bidirectional_streaming/models/bidirectional_model.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Bidirectional model interface for real-time streaming conversations. - -Defines the interface for models that support bidirectional streaming capabilities. -Provides abstractions for different model providers with connection-based communication -patterns that support real-time audio and text interaction. - -Features: -- connection-based persistent connections -- Real-time bidirectional communication -- Provider-agnostic event normalization -- Tool execution integration -""" - -import abc -import logging -from typing import AsyncIterable - -from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent - -logger = logging.getLogger(__name__) - - -class BidirectionalModelSession(abc.ABC): - """Abstract interface for model-specific bidirectional communication connections. - - Defines the contract for managing persistent streaming connections with individual - model providers, handling audio/text input, receiving events, and managing - tool execution results. - """ - - @abc.abstractmethod - async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: - """Receive events from the model in standardized format. - - Converts provider-specific events to a common format that can be - processed uniformly by the event loop. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_audio_content(self, audio_input: AudioInputEvent) -> None: - """Send audio content to the model during an active connection. - - Handles audio encoding and provider-specific formatting while presenting - a simple AudioInputEvent interface. - """ - raise NotImplementedError - - # TODO: remove with interface unification - async def send_image_content(self, image_input: ImageInputEvent) -> None: - """Send image content to the model during an active connection. - - Handles image encoding and provider-specific formatting while presenting - a simple ImageInputEvent interface. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_text_content(self, text: str, **kwargs) -> None: - """Send text content to the model during ongoing generation. - - Allows natural interruption and follow-up questions without requiring - connection restart. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_interrupt(self) -> None: - """Send interruption signal to stop generation immediately. - - Enables responsive conversational experiences where users can - naturally interrupt during model responses. - """ - raise NotImplementedError - - @abc.abstractmethod - async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: - """Send tool execution result to the model. - - Formats and sends tool results according to the provider's specific protocol. - Handles both successful results and error cases through the result dictionary. - """ - raise NotImplementedError - - @abc.abstractmethod - async def close(self) -> None: - """Close the connection and cleanup resources.""" - raise NotImplementedError - - -class BidirectionalModel(abc.ABC): - """Interface for models that support bidirectional streaming. - - Defines the contract for creating persistent streaming connections that support - real-time audio and text communication with AI models. - """ - - @abc.abstractmethod - async def create_bidirectional_connection( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs, - ) -> BidirectionalModelSession: - """Create a bidirectional connection with the model. - - Establishes a persistent connection for real-time communication while - abstracting provider-specific initialization requirements. - """ - raise NotImplementedError diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py index 64c4d7348..afadcba1c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py +++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py @@ -33,7 +33,7 @@ TextOutputEvent, TranscriptEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .base_model import BaseModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py index 134ff73fd..6fdd40c03 100644 --- a/src/strands/experimental/bidirectional_streaming/models/novasonic.py +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -37,7 +37,7 @@ TextOutputEvent, UsageMetricsEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .base_model import BaseModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py index 7d009b1c7..dd436d60e 100644 --- a/src/strands/experimental/bidirectional_streaming/models/openai.py +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -26,7 +26,7 @@ TextOutputEvent, VoiceActivityEvent, ) -from .bidirectional_model import BidirectionalModel, BidirectionalModelSession +from .base_model import BaseModel logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py index 4b215d74e..73f86a469 100644 --- a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -87,6 +87,16 @@ class ImageInputEvent(TypedDict): mimeType: str encoding: Literal["base64", "raw"] +class TextInputEvent(TypedDict): + """Text input event for sending text messages to the model. + + Used for sending text messages through the send() method. + + Attributes: + text: The text content to send to the model. + """ + + text: str class TextOutputEvent(TypedDict): """Text output event from the model during bidirectional streaming. From f0e3d656aba0186d9523d8bc14a285ce5498ea28 Mon Sep 17 00:00:00 2001 From: Jack Yuan Date: Wed, 29 Oct 2025 16:55:59 -0400 Subject: [PATCH 23/23] fix: refine docstring --- .../models/base_model.py | 36 ++++++++++++------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/src/strands/experimental/bidirectional_streaming/models/base_model.py b/src/strands/experimental/bidirectional_streaming/models/base_model.py index 19e3747f4..cf7eb461c 100644 --- a/src/strands/experimental/bidirectional_streaming/models/base_model.py +++ b/src/strands/experimental/bidirectional_streaming/models/base_model.py @@ -1,14 +1,17 @@ -"""Bidirectional model interface for real-time streaming conversations. +"""Protocol interface for real-time bidirectional streaming AI models. -Defines the interface for models that support bidirectional streaming capabilities. -Provides abstractions for different model providers with connection-based communication -patterns that support real-time audio and text interaction. +This module defines the BaseModel protocol that standardizes how AI models handle +real-time, two-way communication with audio, text, images, and tool interactions. +It abstracts provider-specific implementations (Gemini Live, Nova Sonic, OpenAI Realtime) +into a unified interface for seamless integration. -Features: -- connection-based persistent connections -- Real-time bidirectional communication -- Provider-agnostic event normalization -- Tool execution integration +The protocol enables: +- Persistent streaming connections with automatic reconnection +- Real-time audio input/output with interruption support +- Multi-modal content (text, audio, images) in both directions +- Function calling and tool execution during conversations +- Standardized event formats across different AI providers +- Async/await patterns for non-blocking operations """ from typing import AsyncIterable, Protocol, Union @@ -24,10 +27,19 @@ class BaseModel(Protocol): - """Unified interface for bidirectional streaming models. + """Protocol defining the interface for real-time bidirectional AI models. - Combines model configuration and session communication in a single abstraction. - Providers implement this directly without separate model/session classes. + This protocol standardizes how AI models handle persistent streaming connections + for real-time conversations with audio, text, images, and tool interactions. + Implementations handle provider-specific connection management, event processing, + and content serialization while exposing a consistent async interface. + + Models implementing this protocol support: + - WebSocket or streaming API connections + - Real-time audio input/output with voice activity detection + - Multi-modal content streaming (text, audio, images) + - Function calling and tool execution + - Interruption handling and conversation state management """ async def connect(