diff --git a/pyproject.toml b/pyproject.toml index b542c7481..e079ec263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,25 @@ sagemaker = [ "boto3-stubs[sagemaker-runtime]>=1.26.0,<2.0.0", "openai>=1.68.0,<2.0.0", # SageMaker uses OpenAI-compatible interface ] +bidirectional-streaming-nova = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", +] +bidirectional-streaming-openai = [ + "pyaudio>=0.2.13", + "websockets>=12.0,<14.0", +] +bidirectional-streaming = [ + "pyaudio>=0.2.13", + "rx>=3.2.0", + "smithy-aws-core>=0.0.1", + "pytz", + "aws_sdk_bedrock_runtime", + "websockets>=12.0,<14.0", +] otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"] docs = [ "sphinx>=5.0.0,<9.0.0", @@ -69,7 +88,7 @@ a2a = [ "fastapi>=0.115.12,<1.0.0", "starlette>=0.46.2,<1.0.0", ] -all = ["strands-agents[a2a,anthropic,docs,gemini,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] +all = ["strands-agents[a2a,anthropic,docs,gemini,bidirectional-streaming,docs,litellm,llamaapi,mistral,ollama,openai,writer,sagemaker,otel]"] dev = [ "commitizen>=4.4.0,<5.0.0", diff --git a/src/strands/experimental/bidirectional_streaming/__init__.py b/src/strands/experimental/bidirectional_streaming/__init__.py new file mode 100644 index 000000000..695393dd8 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/__init__.py @@ -0,0 +1,44 @@ +"""Bidirectional streaming package.""" + +# Main components - Primary user interface +from .agent.agent import BidirectionalAgent + +# Advanced interfaces (for custom implementations) +from .models.base_model import BidirectionalModel, BidirectionalModelSession + +# Model providers - What users need to create models +from .models.novasonic import NovaSonicBidirectionalModel +from .models.openai import OpenAIRealtimeBidirectionalModel + +# Event types - For type hints and event handling +from .types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalStreamEvent, + InterruptionDetectedEvent, + TextOutputEvent, + UsageMetricsEvent, + VoiceActivityEvent, +) + +__all__ = [ + # Main interface + "BidirectionalAgent", + + # Model providers + "NovaSonicBidirectionalModel", + "OpenAIRealtimeBidirectionalModel", + + # Event types + "AudioInputEvent", + "AudioOutputEvent", + "TextOutputEvent", + "InterruptionDetectedEvent", + "BidirectionalStreamEvent", + "VoiceActivityEvent", + "UsageMetricsEvent", + + # Model interface + "BidirectionalModel", + "BidirectionalModelSession", +] diff --git a/src/strands/experimental/bidirectional_streaming/agent/__init__.py b/src/strands/experimental/bidirectional_streaming/agent/__init__.py new file mode 100644 index 000000000..c490e001d --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/__init__.py @@ -0,0 +1,5 @@ +"""Bidirectional agent for real-time streaming conversations.""" + +from .agent import BidirectionalAgent + +__all__ = ["BidirectionalAgent"] diff --git a/src/strands/experimental/bidirectional_streaming/agent/agent.py b/src/strands/experimental/bidirectional_streaming/agent/agent.py new file mode 100644 index 000000000..d0f9807bd --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/agent/agent.py @@ -0,0 +1,441 @@ +"""Bidirectional Agent for real-time streaming conversations. + +Provides real-time audio and text interaction through persistent streaming sessions. 
+Unlike traditional request-response patterns, this agent maintains long-running +conversations where users can interrupt, provide additional input, and receive +continuous responses including audio output. + +Key capabilities: +- Persistent conversation sessions with concurrent processing +- Real-time audio input/output streaming +- Mid-conversation interruption and tool execution +- Event-driven communication with model providers +""" + +import asyncio +import json +import logging +import random +from concurrent.futures import ThreadPoolExecutor +from typing import Any, AsyncIterable, Callable, Mapping, Optional + +from .... import _identifier +from ....hooks import HookProvider, HookRegistry +from ....telemetry.metrics import EventLoopMetrics +from ....tools.executors import ConcurrentToolExecutor +from ....tools.executors._executor import ToolExecutor +from ....tools.registry import ToolRegistry +from ....tools.watcher import ToolWatcher +from ....types.content import Message, Messages +from ....types.tools import ToolResult, ToolUse +from ....types.traces import AttributeValue +from ..event_loop.bidirectional_event_loop import start_bidirectional_connection, stop_bidirectional_connection +from ..models.base_model import BaseModel +from ..types.bidirectional_streaming import AudioInputEvent, BidirectionalStreamEvent, ImageInputEvent + +logger = logging.getLogger(__name__) + +_DEFAULT_AGENT_NAME = "Strands Agents" +_DEFAULT_AGENT_ID = "default" + + +class BidirectionalAgent: + """Agent for bidirectional streaming conversations. + + Enables real-time audio and text interaction with AI models through persistent + sessions. Supports concurrent tool execution and interruption handling. + """ + + class ToolCaller: + """Call tool as a function for bidirectional agent.""" + + def __init__(self, agent: "BidirectionalAgent") -> None: + """Initialize tool caller with agent reference.""" + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. + self._agent = agent + + def __getattr__(self, name: str) -> Callable[..., Any]: + """Call tool as a function. + + This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). + + Args: + name: The name of the attribute (tool) being accessed. + + Returns: + A function that when called will execute the named tool. + + Raises: + AttributeError: If no tool with the given name exists or if multiple tools match the given name. + """ + + def caller( + user_message_override: Optional[str] = None, + record_direct_tool_call: Optional[bool] = None, + **kwargs: Any, + ) -> Any: + """Call a tool directly by name. + + Args: + user_message_override: Optional custom message to record instead of default + record_direct_tool_call: Whether to record direct tool calls in message history. + For bidirectional agents, this is always True to maintain conversation history. + **kwargs: Keyword arguments to pass to the tool. + + Returns: + The result returned by the tool. + + Raises: + AttributeError: If the tool doesn't exist. 
+ """ + normalized_name = self._find_normalized_tool_name(name) + + # Create unique tool ID and set up the tool request + tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + tool_results: list[ToolResult] = [] + invocation_state = kwargs + + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + _ = event + + return tool_results[0] + + def tcall() -> ToolResult: + return asyncio.run(acall()) + + with ThreadPoolExecutor() as executor: + future = executor.submit(tcall) + tool_result = future.result() + + # Always record direct tool calls for bidirectional agents to maintain conversation history + # Use agent's record_direct_tool_call setting if not overridden + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + + return tool_result + + return caller + + def _find_normalized_tool_name(self, name: str) -> str: + """Lookup the tool represented by name, replacing characters with underscores as necessary.""" + tool_registry = self._agent.tool_registry.registry + + if tool_registry.get(name, None): + return name + + # If the desired name contains underscores, it might be a placeholder for characters that can't be + # represented as python identifiers but are valid as tool names, such as dashes. In that case, find + # all tools that can be represented with the normalized name + if "_" in name: + filtered_tools = [ + tool_name for (tool_name, tool) in tool_registry.items() if tool_name.replace("-", "_") == name + ] + + # The registry itself defends against similar names, so we can just take the first match + if filtered_tools: + return filtered_tools[0] + + raise AttributeError(f"Tool '{name}' not found") + + def __init__( + self, + model: BidirectionalModel, + tools: list | None = None, + system_prompt: str | None = None, + messages: Messages | None = None, + record_direct_tool_call: bool = True, + load_tools_from_directory: bool = False, + agent_id: Optional[str] = None, + name: Optional[str] = None, + tool_executor: Optional[ToolExecutor] = None, + hooks: Optional[list[HookProvider]] = None, + trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + description: Optional[str] = None, + ): + """Initialize bidirectional agent with required model and optional configuration. + + Args: + model: BidirectionalModel instance supporting streaming sessions. + tools: Optional list of tools available to the model. + system_prompt: Optional system prompt for conversations. + messages: Optional conversation history to initialize with. + record_direct_tool_call: Whether to record direct tool calls in message history. + load_tools_from_directory: Whether to load and automatically reload tools in the `./tools/` directory. + agent_id: Optional ID for the agent, useful for session management and multi-agent scenarios. + name: Name of the Agent. + tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.). + hooks: Hooks to be added to the agent hook registry. + trace_attributes: Custom trace attributes to apply to the agent's trace span. 
+ description: Description of what the Agent does. + """ + self.model = model + self.system_prompt = system_prompt + self.messages = messages or [] + + # Agent identification + self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self.name = name or _DEFAULT_AGENT_NAME + self.description = description + + # Tool execution configuration + self.record_direct_tool_call = record_direct_tool_call + self.load_tools_from_directory = load_tools_from_directory + + # Process trace attributes to ensure they're of compatible types + self.trace_attributes: dict[str, AttributeValue] = {} + if trace_attributes: + for k, v in trace_attributes.items(): + if isinstance(v, (str, int, float, bool)) or ( + isinstance(v, list) and all(isinstance(x, (str, int, float, bool)) for x in v) + ): + self.trace_attributes[k] = v + + # Initialize tool registry + self.tool_registry = ToolRegistry() + + if tools is not None: + self.tool_registry.process_tools(tools) + + self.tool_registry.initialize_tools(self.load_tools_from_directory) + + # Initialize tool watcher if directory loading is enabled + if self.load_tools_from_directory: + self.tool_watcher = ToolWatcher(tool_registry=self.tool_registry) + + # Initialize tool executor + self.tool_executor = tool_executor or ConcurrentToolExecutor() + + # Initialize hooks system + self.hooks = HookRegistry() + if hooks: + for hook in hooks: + self.hooks.add_hook(hook) + + # Initialize other components + self.event_loop_metrics = EventLoopMetrics() + self.tool_caller = BidirectionalAgent.ToolCaller(self) + + # Session management + self._session = None + self._output_queue = asyncio.Queue() + + @property + def tool(self) -> ToolCaller: + """Call tool as a function. + + Returns: + Tool caller through which user can invoke tool as a function. + + Example: + ``` + agent = BidirectionalAgent(model=model, tools=[calculator]) + agent.tool.calculator(expression="2+2") + ``` + """ + return self.tool_caller + + @property + def tool_names(self) -> list[str]: + """Get a list of all registered tool names. + + Returns: + Names of all tools available to this agent. + """ + all_tools = self.tool_registry.get_all_tools_config() + return list(all_tools.keys()) + + def _record_tool_execution( + self, + tool: ToolUse, + tool_result: ToolResult, + user_message_override: Optional[str], + ) -> None: + """Record a tool execution in the message history. + + Creates a sequence of messages that represent the tool execution: + + 1. A user message describing the tool call + 2. An assistant message with the tool use + 3. A user message with the tool result + 4. An assistant message acknowledging the tool call + + Args: + tool: The tool call information. + tool_result: The result returned by the tool. + user_message_override: Optional custom message to include. 
+ """ + # Filter tool input parameters to only include those defined in tool spec + filtered_input = self._filter_tool_parameters_for_recording(tool["name"], tool["input"]) + + # Create user message describing the tool call + input_parameters = json.dumps(filtered_input, default=lambda o: f"<>") + + user_msg_content = [ + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {input_parameters}\n")} + ] + + # Add override message if provided + if user_message_override: + user_msg_content.insert(0, {"text": f"{user_message_override}\n"}) + + # Create filtered tool use for message history + filtered_tool: ToolUse = { + "toolUseId": tool["toolUseId"], + "name": tool["name"], + "input": filtered_input, + } + + # Create the message sequence + user_msg: Message = { + "role": "user", + "content": user_msg_content, + } + tool_use_msg: Message = { + "role": "assistant", + "content": [{"toolUse": filtered_tool}], + } + tool_result_msg: Message = { + "role": "user", + "content": [{"toolResult": tool_result}], + } + assistant_msg: Message = { + "role": "assistant", + "content": [{"text": f"agent.tool.{tool['name']} was called."}], + } + + # Add to message history + self.messages.append(user_msg) + self.messages.append(tool_use_msg) + self.messages.append(tool_result_msg) + self.messages.append(assistant_msg) + + logger.debug("Direct tool call recorded in message history: %s", tool["name"]) + + def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: dict[str, Any]) -> dict[str, Any]: + """Filter input parameters to only include those defined in the tool specification. + + Args: + tool_name: Name of the tool to get specification for + input_params: Original input parameters + + Returns: + Filtered parameters containing only those defined in tool spec + """ + all_tools_config = self.tool_registry.get_all_tools_config() + tool_spec = all_tools_config.get(tool_name) + + if not tool_spec or "inputSchema" not in tool_spec: + return input_params.copy() + + properties = tool_spec["inputSchema"]["json"]["properties"] + return {k: v for k, v in input_params.items() if k in properties} + + async def start(self) -> None: + """Start a persistent bidirectional conversation session. + + Initializes the streaming session and starts background tasks for processing + model events, tool execution, and session management. + + Raises: + ValueError: If conversation already active. + ConnectionError: If session creation fails. + """ + if self._session and self._session.active: + raise ValueError("Conversation already active. Call end() first.") + + logger.debug("Conversation start - initializing session") + self._session = await start_bidirectional_connection(self) + + async def send(self, input_data: str | AudioInputEvent | ImageInputEvent) -> None: + """Send input to the model (text, audio, or image). + + Unified method for sending text, audio, and image input to the model during + an active conversation session. + + Args: + input_data: String for text, AudioInputEvent for audio, or ImageInputEvent for images. + + Raises: + ValueError: If no active session or invalid input type. 
+ """ + self._validate_active_session() + + if isinstance(input_data, str): + # Add user text message to history + self.messages.append({"role": "user", "content": input_data}) + + logger.debug("Text sent: %d characters", len(input_data)) + await self._session.model_session.send_text_content(input_data) + elif isinstance(input_data, dict) and "audioData" in input_data: + # Handle audio input + await self._session.model_session.send_audio_content(input_data) + elif isinstance(input_data, dict) and "imageData" in input_data: + # Handle image input (ImageInputEvent) + await self._session.model_session.send_image_content(input_data) + else: + raise ValueError( + "Input must be either a string (text), AudioInputEvent " + "(dict with audioData, format, sampleRate, channels), or ImageInputEvent " + "(dict with imageData, mimeType, encoding)" + ) + + async def receive(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive events from the model including audio, text, and tool calls. + + Yields model output events processed by background tasks including audio output, + text responses, tool calls, and session updates. + + Yields: + BidirectionalStreamEvent: Events from the model session. + """ + while self._session and self._session.active: + try: + event = await asyncio.wait_for(self._output_queue.get(), timeout=0.1) + yield event + except asyncio.TimeoutError: + continue + + async def interrupt(self) -> None: + """Interrupt the current model generation and clear audio buffers. + + Sends interruption signal to stop generation immediately and clears + pending audio output for responsive conversation flow. + + Raises: + ValueError: If no active session. + """ + self._validate_active_session() + await self._session.model_session.send_interrupt() + + async def end(self) -> None: + """End the conversation session and cleanup all resources. + + Terminates the streaming session, cancels background tasks, and + closes the connection to the model provider. + """ + if self._session: + await stop_bidirectional_connection(self._session) + self._session = None + + def _validate_active_session(self) -> None: + """Validate that an active session exists. + + Raises: + ValueError: If no active session. + """ + if not self._session or not self._session.active: + raise ValueError("No active conversation. Call start() first.") diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py new file mode 100644 index 000000000..af8c4e1e1 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/__init__.py @@ -0,0 +1,15 @@ +"""Event loop management for bidirectional streaming.""" + +from .bidirectional_event_loop import ( + BidirectionalConnection, + bidirectional_event_loop_cycle, + start_bidirectional_connection, + stop_bidirectional_connection, +) + +__all__ = [ + "BidirectionalConnection", + "start_bidirectional_connection", + "stop_bidirectional_connection", + "bidirectional_event_loop_cycle", +] diff --git a/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py new file mode 100644 index 000000000..f82b5911d --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/event_loop/bidirectional_event_loop.py @@ -0,0 +1,480 @@ +"""Bidirectional session management for concurrent streaming conversations. 
+
+Manages bidirectional communication sessions with concurrent processing of model events,
+tool execution, and audio processing. Provides coordination between background tasks
+while maintaining a simple interface for agent interaction.
+
+Features:
+- Concurrent task management for model events and tool execution
+- Interruption handling with audio buffer clearing
+- Tool execution with cancellation support
+- Session lifecycle management
+"""
+
+import asyncio
+import logging
+import traceback
+import uuid
+
+from ....tools._validator import validate_and_prepare_tools
+from ....telemetry.metrics import Trace
+from ....types._events import ToolResultEvent, ToolStreamEvent
+from ....types.content import Message
+from ....types.tools import ToolResult, ToolUse
+from ..models.base_model import BidirectionalModelSession
+
+logger = logging.getLogger(__name__)
+
+# Session constants
+TOOL_QUEUE_TIMEOUT = 0.5
+SUPERVISION_INTERVAL = 0.1
+
+
+class BidirectionalConnection:
+    """Session wrapper for bidirectional communication with concurrent task management.
+
+    Coordinates background tasks for model event processing, tool execution, and audio
+    handling while providing a simple interface for agent interactions.
+    """
+
+    def __init__(self, model_session: BidirectionalModelSession, agent: "BidirectionalAgent") -> None:
+        """Initialize session with model session and agent reference.
+
+        Args:
+            model_session: Provider-specific bidirectional model session.
+            agent: BidirectionalAgent instance for tool registry access.
+        """
+        self.model_session = model_session
+        self.agent = agent
+        self.active = True
+
+        # Background processing coordination
+        self.background_tasks = []
+        self.tool_queue = asyncio.Queue()
+        self.audio_output_queue = asyncio.Queue()
+
+        # Task management for cleanup
+        self.pending_tool_tasks: dict[str, asyncio.Task] = {}
+
+        # Interruption handling (model-agnostic)
+        self.interrupted = False
+        self.interruption_lock = asyncio.Lock()
+
+        # Tool execution tracking
+        self.tool_count = 0
+
+
+async def start_bidirectional_connection(agent: "BidirectionalAgent") -> BidirectionalConnection:
+    """Initialize bidirectional session with concurrent background tasks.
+
+    Creates a model-specific session and starts background tasks for processing
+    model events, executing tools, and managing the session lifecycle.
+
+    Args:
+        agent: BidirectionalAgent instance.
+
+    Returns:
+        BidirectionalConnection: Active session with background tasks running.
+ """ + logger.debug("Starting bidirectional session - initializing model session") + + # Create provider-specific session + model_session = await agent.model.create_bidirectional_connection( + system_prompt=agent.system_prompt, tools=agent.tool_registry.get_all_tool_specs(), messages=agent.messages + ) + + # Create session wrapper for background processing + session = BidirectionalConnection(model_session=model_session, agent=agent) + + # Start concurrent background processors IMMEDIATELY after session creation + # This is critical - Nova Sonic needs response processing during initialization + logger.debug("Starting background processors for concurrent processing") + session.background_tasks = [ + asyncio.create_task(_process_model_events(session)), # Handle model responses + asyncio.create_task(_process_tool_execution(session)), # Execute tools concurrently + ] + + # Start main coordination cycle + session.main_cycle_task = asyncio.create_task(bidirectional_event_loop_cycle(session)) + + logger.debug("Session ready with %d background tasks", len(session.background_tasks)) + return session + + +async def stop_bidirectional_connection(session: BidirectionalConnection) -> None: + """End session and cleanup resources including background tasks. + + Args: + session: BidirectionalConnection to cleanup. + """ + if not session.active: + return + + logger.debug("Session cleanup starting") + session.active = False + + # Cancel pending tool tasks + for _, task in session.pending_tool_tasks.items(): + if not task.done(): + task.cancel() + + # Cancel background tasks + for task in session.background_tasks: + if not task.done(): + task.cancel() + + # Cancel main cycle task + if hasattr(session, "main_cycle_task") and not session.main_cycle_task.done(): + session.main_cycle_task.cancel() + + # Wait for tasks to complete + all_tasks = session.background_tasks + list(session.pending_tool_tasks.values()) + if hasattr(session, "main_cycle_task"): + all_tasks.append(session.main_cycle_task) + + if all_tasks: + await asyncio.gather(*all_tasks, return_exceptions=True) + + # Close model session + await session.model_session.close() + logger.debug("Session closed") + + +async def bidirectional_event_loop_cycle(session: BidirectionalConnection) -> None: + """Main event loop coordinator that runs continuously during the session. + + Monitors background tasks, manages session state, and handles session lifecycle. + Provides supervision for concurrent model event processing and tool execution. + + Args: + session: BidirectionalConnection to coordinate. + """ + while session.active: + try: + # Check if background processors are still running + if all(task.done() for task in session.background_tasks): + logger.debug("Session end - all processors completed") + session.active = False + break + + # Check for failed background tasks + for i, task in enumerate(session.background_tasks): + if task.done() and not task.cancelled(): + exception = task.exception() + if exception: + logger.error("Session error in processor %d: %s", i, str(exception)) + session.active = False + raise exception + + # Brief pause before next supervision check + await asyncio.sleep(SUPERVISION_INTERVAL) + + except asyncio.CancelledError: + break + except Exception as e: + logger.error("Event loop error: %s", str(e)) + session.active = False + raise + + +async def _handle_interruption(session: BidirectionalConnection) -> None: + """Handle interruption detection with task cancellation and audio buffer clearing. 
+ + Cancels pending tool tasks and clears audio output queues to ensure responsive + interruption handling during conversations. Protected by async lock to prevent + concurrent execution and race conditions. + + Args: + session: BidirectionalConnection to handle interruption for. + """ + async with session.interruption_lock: + # If already interrupted, skip duplicate processing + if session.interrupted: + logger.debug("Interruption already in progress") + return + + logger.debug("Interruption detected") + session.interrupted = True + + # Cancel all pending tool execution tasks + cancelled_tools = 0 + for _task_id, task in list(session.pending_tool_tasks.items()): + if not task.done(): + task.cancel() + cancelled_tools += 1 + logger.debug("Tool task cancelled: %s", _task_id) + + if cancelled_tools > 0: + logger.debug("Tool tasks cancelled: %d", cancelled_tools) + + # Clear all queued audio output events + cleared_count = 0 + while True: + try: + session.audio_output_queue.get_nowait() + cleared_count += 1 + except asyncio.QueueEmpty: + break + + # Also clear the agent's audio output queue + audio_cleared = 0 + # Create a temporary list to hold non-audio events + temp_events = [] + try: + while True: + event = session.agent._output_queue.get_nowait() + if event.get("audioOutput"): + audio_cleared += 1 + else: + # Keep non-audio events + temp_events.append(event) + except asyncio.QueueEmpty: + pass + + # Put back non-audio events + for event in temp_events: + session.agent._output_queue.put_nowait(event) + + if audio_cleared > 0: + logger.debug("Agent audio queue cleared: %d events", audio_cleared) + + if cleared_count > 0: + logger.debug("Session audio queue cleared: %d events", cleared_count) + + # Reset interruption flag after clearing (automatic recovery) + session.interrupted = False + logger.debug("Interruption handled - tools cancelled: %d, audio cleared: %d", cancelled_tools, cleared_count) + + +async def _process_model_events(session: BidirectionalConnection) -> None: + """Process model events and convert them to Strands format. + + Background task that handles all model responses, converts provider-specific + events to standardized formats, and manages interruption detection. + + Args: + session: BidirectionalConnection containing model session. 
+ """ + logger.debug("Model events processor started") + try: + async for provider_event in session.model_session.receive_events(): + if not session.active: + break + + # Basic validation - skip invalid events + if not isinstance(provider_event, dict): + continue + + strands_event = provider_event + + # Handle interruption detection (provider converts raw patterns to interruptionDetected) + if strands_event.get("interruptionDetected"): + logger.debug("Interruption forwarded") + await _handle_interruption(session) + # Forward interruption event to agent for application-level handling + await session.agent._output_queue.put(strands_event) + continue + + # Queue tool requests for concurrent execution + if strands_event.get("toolUse"): + tool_name = strands_event["toolUse"].get("name") + logger.debug("Tool usage detected: %s", tool_name) + await session.tool_queue.put(strands_event["toolUse"]) + continue + + # Send output events to Agent for receive() method + if strands_event.get("audioOutput") or strands_event.get("textOutput"): + await session.agent._output_queue.put(strands_event) + + # Update Agent conversation history using existing patterns + if strands_event.get("messageStop"): + logger.debug("Message added to history") + session.agent.messages.append(strands_event["messageStop"]["message"]) + + # Handle user audio transcripts - add to message history + if strands_event.get("textOutput") and strands_event["textOutput"].get("role") == "user": + user_transcript = strands_event["textOutput"]["text"] + if user_transcript.strip(): # Only add non-empty transcripts + user_message = {"role": "user", "content": user_transcript} + session.agent.messages.append(user_message) + logger.debug("User transcript added to history") + + except Exception as e: + logger.error("Model events error: %s", str(e)) + traceback.print_exc() + finally: + logger.debug("Model events processor stopped") + + +async def _process_tool_execution(session: BidirectionalConnection) -> None: + """Execute tools concurrently with interruption support. + + Background task that manages tool execution without blocking model event + processing or user interaction. Uses proper asyncio cancellation for + interruption handling rather than manual state checks. + + Args: + session: BidirectionalConnection containing tool queue. 
+ """ + logger.debug("Tool execution processor started") + while session.active: + try: + tool_use = await asyncio.wait_for(session.tool_queue.get(), timeout=TOOL_QUEUE_TIMEOUT) + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + + session.tool_count += 1 + print(f"\nTool #{session.tool_count}: {tool_name}") + + logger.debug("Tool execution started: %s (id: %s)", tool_name, tool_id) + + task_id = str(uuid.uuid4()) + task = asyncio.create_task(_execute_tool_with_strands(session, tool_use)) + session.pending_tool_tasks[task_id] = task + + def cleanup_task(completed_task: asyncio.Task, task_id: str = task_id) -> None: + try: + # Remove from pending tasks + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + # Log completion status + if completed_task.cancelled(): + logger.debug("Tool task cancelled: %s", task_id) + elif completed_task.exception(): + logger.error("Tool task error: %s - %s", task_id, str(completed_task.exception())) + else: + logger.debug("Tool task completed: %s", task_id) + except Exception as e: + logger.error("Tool task cleanup failed: %s - %s", task_id, str(e)) + + task.add_done_callback(cleanup_task) + + except asyncio.TimeoutError: + if not session.active: + break + # Remove completed tasks from tracking + completed_tasks = [task_id for task_id, task in session.pending_tool_tasks.items() if task.done()] + for task_id in completed_tasks: + if task_id in session.pending_tool_tasks: + del session.pending_tool_tasks[task_id] + + if completed_tasks: + logger.debug("Periodic task cleanup: %d tasks", len(completed_tasks)) + + continue + except Exception as e: + logger.error("Tool execution error: %s", str(e)) + if not session.active: + break + + logger.debug("Tool execution processor stopped") + + + + + +async def _execute_tool_with_strands(session: BidirectionalConnection, tool_use: dict) -> None: + """Execute tool using the complete Strands tool execution system. + + Uses proper Strands ToolExecutor system with validation, error handling, + and event streaming. + + Args: + session: BidirectionalConnection for context. + tool_use: Tool use event to execute. 
+ """ + tool_name = tool_use.get("name") + tool_id = tool_use.get("toolUseId") + + logger.debug("Executing tool: %s (id: %s)", tool_name, tool_id) + + try: + # Create message structure for validation + tool_message: Message = {"role": "assistant", "content": [{"toolUse": tool_use}]} + + # Use Strands validation system + tool_uses: list[ToolUse] = [] + tool_results: list[ToolResult] = [] + invalid_tool_use_ids: list[str] = [] + + validate_and_prepare_tools(tool_message, tool_uses, tool_results, invalid_tool_use_ids) + + # Filter valid tools + valid_tool_uses = [tu for tu in tool_uses if tu.get("toolUseId") not in invalid_tool_use_ids] + + if not valid_tool_uses: + logger.warning("No valid tools after validation: %s", tool_name) + return + + # Create invocation state for tool execution + invocation_state = { + "agent": session.agent, + "model": session.agent.model, + "messages": session.agent.messages, + "system_prompt": session.agent.system_prompt, + } + + # Create cycle trace and span + cycle_trace = Trace("Bidirectional Tool Execution") + cycle_span = None + + tool_events = session.agent.tool_executor._execute( + session.agent, + valid_tool_uses, + tool_results, + cycle_trace, + cycle_span, + invocation_state + ) + + # Process tool events and send results to provider + async for tool_event in tool_events: + if isinstance(tool_event, ToolResultEvent): + tool_result = tool_event.tool_result + tool_use_id = tool_result.get("toolUseId") + + # Send result through provider-specific session + await session.model_session.send_tool_result(tool_use_id, tool_result) + logger.debug("Tool result sent: %s", tool_use_id) + + # Handle streaming events if needed later + elif isinstance(tool_event, ToolStreamEvent): + logger.debug("Tool stream event: %s", tool_event) + pass + + # Add tool result message to conversation history + if tool_results: + from ....hooks import MessageAddedEvent + + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + session.agent.messages.append(tool_result_message) + session.agent.hooks.invoke_callbacks(MessageAddedEvent(agent=session.agent, message=tool_result_message)) + logger.debug("Tool result message added to history: %s", tool_name) + + logger.debug("Tool execution completed: %s", tool_name) + + except asyncio.CancelledError: + logger.debug("Tool execution cancelled: %s (id: %s)", tool_name, tool_id) + raise + except Exception as e: + logger.error("Tool execution error: %s - %s", tool_name, str(e)) + + # Send error result + error_result: ToolResult = { + "toolUseId": tool_id, + "status": "error", + "content": [{"text": f"Error: {str(e)}"}] + } + try: + await session.model_session.send_tool_result(tool_id, error_result) + logger.debug("Error result sent: %s", tool_id) + except Exception: + logger.error("Failed to send error result: %s", tool_id) + pass # Session might be closed + + diff --git a/src/strands/experimental/bidirectional_streaming/models/__init__.py b/src/strands/experimental/bidirectional_streaming/models/__init__.py new file mode 100644 index 000000000..16e2d18c9 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/__init__.py @@ -0,0 +1,17 @@ +"""Bidirectional model interfaces and implementations.""" + +from .base_model import BidirectionalModel, BidirectionalModelSession +from .gemini_live import GeminiLiveBidirectionalModel, GeminiLiveSession +from .novasonic import NovaSonicBidirectionalModel, NovaSonicSession +from .openai import 
OpenAIRealtimeBidirectionalModel, OpenAIRealtimeSession
+
+__all__ = [
+    "BidirectionalModel",
+    "BidirectionalModelSession",
+    "GeminiLiveBidirectionalModel",
+    "GeminiLiveSession",
+    "NovaSonicBidirectionalModel",
+    "NovaSonicSession",
+    "OpenAIRealtimeBidirectionalModel",
+    "OpenAIRealtimeSession",
+]
diff --git a/src/strands/experimental/bidirectional_streaming/models/base_model.py b/src/strands/experimental/bidirectional_streaming/models/base_model.py
new file mode 100644
index 000000000..cf7eb461c
--- /dev/null
+++ b/src/strands/experimental/bidirectional_streaming/models/base_model.py
@@ -0,0 +1,69 @@
+"""Model interfaces for real-time bidirectional streaming.
+
+Defines the BidirectionalModel and BidirectionalModelSession interfaces used by
+BidirectionalAgent and the bidirectional event loop. Provider implementations
+(Nova Sonic, Gemini Live, OpenAI Realtime) manage their own connections, convert
+provider events into the standardized event format, and serialize outgoing
+content behind these interfaces.
+"""
+
+from typing import Any, AsyncIterable, Protocol
+
+from ....types.content import Messages
+from ....types.tools import ToolResult, ToolSpec
+from ..types.bidirectional_streaming import (
+    AudioInputEvent,
+    BidirectionalStreamEvent,
+    ImageInputEvent,
+)
+
+
+class BidirectionalModelSession(Protocol):
+    """Protocol for an active streaming session with a bidirectional model.
+
+    A session owns a persistent connection and exposes methods for sending text,
+    audio, image, and tool-result content while yielding standardized events for
+    the agent and event loop to consume.
+    """
+
+    def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]:
+        """Receive events from the model in standardized format."""
+        ...
+
+    async def send_text_content(self, text: str, **kwargs: Any) -> None:
+        """Send text input to the model."""
+        ...
+
+    async def send_audio_content(self, audio_input: AudioInputEvent) -> None:
+        """Send an audio chunk to the model."""
+        ...
+
+    async def send_image_content(self, image_input: ImageInputEvent) -> None:
+        """Send an image frame to the model."""
+        ...
+
+    async def send_interrupt(self) -> None:
+        """Interrupt the model's current generation."""
+        ...
+
+    async def send_tool_result(self, tool_use_id: str, result: ToolResult) -> None:
+        """Send a tool execution result back to the model."""
+        ...
+
+    async def close(self) -> None:
+        """Close the connection and release resources."""
+        ...
+
+
+class BidirectionalModel(Protocol):
+    """Protocol for models that support bidirectional streaming sessions."""
+
+    async def create_bidirectional_connection(
+        self,
+        system_prompt: str | None = None,
+        tools: list[ToolSpec] | None = None,
+        messages: Messages | None = None,
+        **kwargs: Any,
+    ) -> BidirectionalModelSession:
+        """Create a new streaming session with the model."""
+        ...
diff --git a/src/strands/experimental/bidirectional_streaming/models/gemini_live.py b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py
new file mode 100644
index 000000000..afadcba1c
--- /dev/null
+++ b/src/strands/experimental/bidirectional_streaming/models/gemini_live.py
@@ -0,0 +1,499 @@
+"""Gemini Live API bidirectional model provider using official Google GenAI SDK.
+ +Implements the BidirectionalModel interface for Google's Gemini Live API using the +official Google GenAI SDK for simplified and robust WebSocket communication. + +Key improvements over custom WebSocket implementation: +- Uses official google-genai SDK with native Live API support +- Simplified session management with client.aio.live.connect() +- Built-in tool integration and event handling +- Automatic WebSocket connection management and error handling +- Native support for audio/text streaming and interruption +""" + +import asyncio +import base64 +import logging +import uuid +from typing import Any, AsyncIterable, Dict, List, Optional + +from google import genai +from google.genai import types as genai_types +from google.genai.types import LiveServerMessage, LiveServerContent + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + ImageInputEvent, + InterruptionDetectedEvent, + TextOutputEvent, + TranscriptEvent, +) +from .base_model import BaseModel + +logger = logging.getLogger(__name__) + +# Audio format constants +GEMINI_INPUT_SAMPLE_RATE = 16000 +GEMINI_OUTPUT_SAMPLE_RATE = 24000 +GEMINI_CHANNELS = 1 + + +class GeminiLiveSession(BidirectionalModelSession): + """Gemini Live API session using official Google GenAI SDK. + + Provides a clean interface to Gemini Live API using the official SDK, + eliminating custom WebSocket handling and providing robust error handling. + """ + + def __init__(self, client: genai.Client, model_id: str, config: Dict[str, Any]): + """Initialize Gemini Live API session. + + Args: + client: Gemini client instance + model_id: Model identifier + config: Model configuration including live config + """ + self.client = client + self.model_id = model_id + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + self.live_session = None + self.live_session_cm = None + + + + async def initialize( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None + ) -> None: + """Initialize Gemini Live API session by creating the connection.""" + + try: + # Build live config + live_config = self.config.get("live_config") + + if live_config is None: + raise ValueError("live_config is required but not found in session config") + + # Create the context manager + self.live_session_cm = self.client.aio.live.connect( + model=self.model_id, + config=live_config + ) + + # Enter the context manager + self.live_session = await self.live_session_cm.__aenter__() + + # Send initial message history if provided + if messages: + await self._send_message_history(messages) + + + except Exception as e: + logger.error("Error initializing Gemini Live session: %s", e) + raise + + async def _send_message_history(self, messages: Messages) -> None: + """Send conversation history to Gemini Live API. + + Sends each message as a separate turn with the correct role to maintain + proper conversation context. Follows the same pattern as the non-bidirectional + Gemini model implementation. 
+ """ + if not messages: + return + + # Convert each message to Gemini format and send separately + for message in messages: + content_parts = [] + for content_block in message["content"]: + if "text" in content_block: + content_parts.append(genai_types.Part(text=content_block["text"])) + + if content_parts: + # Map role correctly - Gemini uses "user" and "model" roles + # "assistant" role from Messages format maps to "model" in Gemini + role = "model" if message["role"] == "assistant" else message["role"] + content = genai_types.Content(role=role, parts=content_parts) + await self.live_session.send_client_content(turns=content) + + async def receive_events(self) -> AsyncIterable[Dict[str, Any]]: + """Receive Gemini Live API events and convert to provider-agnostic format.""" + + # Emit connection start event + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "gemini_live", "model_id": self.config.get("model_id")} + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + # Wrap in while loop to restart after turn_complete (SDK limitation workaround) + while self._active: + try: + async for message in self.live_session.receive(): + if not self._active: + break + + # Convert to provider-agnostic format + provider_event = self._convert_gemini_live_event(message) + if provider_event: + yield provider_event + + # SDK exits receive loop after turn_complete - restart automatically + if self._active: + logger.debug("Restarting receive loop after turn completion") + + except Exception as e: + logger.error("Error in receive iteration: %s", e) + # Small delay before retrying to avoid tight error loops + await asyncio.sleep(0.1) + + except Exception as e: + logger.error("Fatal error in receive loop: %s", e) + finally: + # Emit connection end event when exiting + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "gemini_live"} + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_gemini_live_event(self, message: LiveServerMessage) -> Optional[Dict[str, Any]]: + """Convert Gemini Live API events to provider-agnostic format. 
+ + Handles different types of text output: + - inputTranscription: User's speech transcribed to text (emitted as transcript event) + - outputTranscription: Model's audio transcribed to text (emitted as transcript event) + - modelTurn text: Actual text response from the model (emitted as textOutput) + """ + try: + # Handle interruption first (from server_content) + if message.server_content and message.server_content.interrupted: + interruption: InterruptionDetectedEvent = { + "reason": "user_input" + } + return {"interruptionDetected": interruption} + + # Handle input transcription (user's speech) - emit as transcript event + if message.server_content and message.server_content.input_transcription: + input_transcript = message.server_content.input_transcription + # Check if the transcription object has text content + if hasattr(input_transcript, 'text') and input_transcript.text: + transcription_text = input_transcript.text + logger.debug(f"Input transcription detected: {transcription_text}") + transcript: TranscriptEvent = { + "text": transcription_text, + "role": "user", + "type": "input" + } + return {"transcript": transcript} + + # Handle output transcription (model's audio) - emit as transcript event + if message.server_content and message.server_content.output_transcription: + output_transcript = message.server_content.output_transcription + # Check if the transcription object has text content + if hasattr(output_transcript, 'text') and output_transcript.text: + transcription_text = output_transcript.text + logger.debug(f"Output transcription detected: {transcription_text}") + transcript: TranscriptEvent = { + "text": transcription_text, + "role": "assistant", + "type": "output" + } + return {"transcript": transcript} + + # Handle actual text output from model (not transcription) + # The SDK's message.text property accesses modelTurn.parts[].text + if message.text: + text_output: TextOutputEvent = { + "text": message.text, + "role": "assistant" + } + return {"textOutput": text_output} + + # Handle audio output using SDK's built-in data property + if message.data: + audio_output: AudioOutputEvent = { + "audioData": message.data, + "format": "pcm", + "sampleRate": GEMINI_OUTPUT_SAMPLE_RATE, + "channels": GEMINI_CHANNELS, + "encoding": "raw" + } + return {"audioOutput": audio_output} + + # Handle tool calls + if message.tool_call and message.tool_call.function_calls: + for func_call in message.tool_call.function_calls: + tool_use_event: ToolUse = { + "toolUseId": func_call.id, + "name": func_call.name, + "input": func_call.args or {} + } + return {"toolUse": tool_use_event} + + # Silently ignore setup_complete, turn_complete, generation_complete, and usage_metadata messages + return None + + except Exception as e: + logger.error("Error converting Gemini Live event: %s", e) + logger.error("Message type: %s", type(message).__name__) + logger.error("Message attributes: %s", [attr for attr in dir(message) if not attr.startswith('_')]) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content using Gemini Live API. + + Gemini Live expects continuous audio streaming via send_realtime_input. + This automatically triggers VAD and can interrupt ongoing responses. 
+ """ + if not self._active: + return + + try: + # Create audio blob for the SDK + audio_blob = genai_types.Blob( + data=audio_input["audioData"], + mime_type=f"audio/pcm;rate={GEMINI_INPUT_SAMPLE_RATE}" + ) + + # Send real-time audio input - this automatically handles VAD and interruption + await self.live_session.send_realtime_input(audio=audio_blob) + + except Exception as e: + logger.error("Error sending audio content: %s", e) + + async def send_image_content(self, image_input: ImageInputEvent) -> None: + """Send image content using Gemini Live API. + + Sends image frames following the same pattern as the GitHub example. + Images are sent as base64-encoded data with MIME type. + """ + if not self._active: + return + + try: + # Prepare the message based on encoding + if image_input["encoding"] == "base64": + # Data is already base64 encoded + if isinstance(image_input["imageData"], bytes): + data_str = image_input["imageData"].decode() + else: + data_str = image_input["imageData"] + else: + # Raw bytes - need to base64 encode + data_str = base64.b64encode(image_input["imageData"]).decode() + + # Create the message in the format expected by Gemini Live + msg = { + "mime_type": image_input["mimeType"], + "data": data_str + } + + # Send using the same method as the GitHub example + await self.live_session.send(input=msg) + + except Exception as e: + logger.error("Error sending image content: %s", e) + + async def send_text_content(self, text: str, **kwargs) -> None: + """Send text content using Gemini Live API.""" + if not self._active: + return + + try: + # Create content with text + content = genai_types.Content( + role="user", + parts=[genai_types.Part(text=text)] + ) + + # Send as client content + await self.live_session.send_client_content(turns=content) + + except Exception as e: + logger.error("Error sending text content: %s", e) + + async def send_interrupt(self) -> None: + """Send interruption signal to Gemini Live API. + + Gemini Live uses automatic VAD-based interruption. When new audio input + is detected, it automatically interrupts the ongoing generation. + We don't need to send explicit interrupt signals like Nova Sonic. + """ + if not self._active: + return + + try: + # Gemini Live handles interruption automatically through VAD + # When new audio input is sent via send_realtime_input, it automatically + # interrupts any ongoing generation. No explicit interrupt signal needed. 
+ logger.debug("Interrupt requested - Gemini Live handles this automatically via VAD") + + except Exception as e: + logger.error("Error in interrupt handling: %s", e) + + async def send_tool_result(self, tool_use_id: str, result: Dict[str, Any]) -> None: + """Send tool result using Gemini Live API.""" + if not self._active: + return + + try: + # Create function response + func_response = genai_types.FunctionResponse( + id=tool_use_id, + name=tool_use_id, # Gemini uses name as identifier + response=result + ) + + # Send tool response + await self.live_session.send_tool_response(function_responses=[func_response]) + except Exception as e: + logger.error("Error sending tool result: %s", e) + + async def send_tool_error(self, tool_use_id: str, error: str) -> None: + """Send tool error using Gemini Live API.""" + error_result = {"error": error} + await self.send_tool_result(tool_use_id, error_result) + + async def close(self) -> None: + """Close Gemini Live API connection.""" + if not self._active: + return + + self._active = False + + try: + # Exit the context manager properly + if self.live_session_cm: + await self.live_session_cm.__aexit__(None, None, None) + except Exception as e: + logger.error("Error closing Gemini Live session: %s", e) + raise + + +class GeminiLiveBidirectionalModel(BidirectionalModel): + """Gemini Live API model implementation using official Google GenAI SDK. + + Provides access to Google's Gemini Live API through the bidirectional + streaming interface, using the official SDK for robust and simple integration. + """ + + def __init__( + self, + model_id: str = "models/gemini-2.0-flash-live-preview-04-09", + api_key: Optional[str] = None, + **config + ): + """Initialize Gemini Live API bidirectional model. + + Args: + model_id: Gemini Live model identifier. + api_key: Google AI API key for authentication. + **config: Additional configuration. + """ + self.model_id = model_id + self.api_key = api_key + self.config = config + + # Create Gemini client with proper API version + client_kwargs = {} + if api_key: + client_kwargs["api_key"] = api_key + + # Use v1alpha for Live API as it has better model support + client_kwargs["http_options"] = {"api_version": "v1alpha"} + + self.client = genai.Client(**client_kwargs) + + async def create_bidirectional_connection( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + messages: Optional[Messages] = None, + **kwargs + ) -> BidirectionalModelSession: + """Create Gemini Live API bidirectional connection using official SDK.""" + + try: + # Build configuration + live_config = self._build_live_config(system_prompt, tools, **kwargs) + + # Create session config + session_config = self._get_session_config() + session_config["live_config"] = live_config + + # Create and initialize session wrapper + session = GeminiLiveSession(self.client, self.model_id, session_config) + await session.initialize(system_prompt, tools, messages) + + return session + + except Exception as e: + logger.error("Failed to create Gemini Live connection: %s", e) + raise + + def _build_live_config( + self, + system_prompt: Optional[str] = None, + tools: Optional[List[ToolSpec]] = None, + **kwargs + ) -> Dict[str, Any]: + """Build LiveConnectConfig for the official SDK. + + Simply passes through all config parameters from params, allowing users + to configure any Gemini Live API parameter directly. 
+ """ + # Start with user config from params + config_dict = {} + if "params" in self.config: + config_dict.update(self.config["params"]) + + # Override with any kwargs + config_dict.update(kwargs) + + # Add system instruction if provided + if system_prompt: + config_dict["system_instruction"] = system_prompt + + # Add tools if provided + if tools: + config_dict["tools"] = self._format_tools_for_live_api(tools) + + return config_dict + + def _format_tools_for_live_api(self, tool_specs: List[ToolSpec]) -> List[genai_types.Tool]: + """Format tool specs for Gemini Live API.""" + if not tool_specs: + return [] + + return [ + genai_types.Tool( + function_declarations=[ + genai_types.FunctionDeclaration( + description=tool_spec["description"], + name=tool_spec["name"], + parameters_json_schema=tool_spec["inputSchema"]["json"], + ) + for tool_spec in tool_specs + ], + ), + ] + + def _get_session_config(self) -> Dict[str, Any]: + """Get session configuration for Gemini Live API.""" + return { + "model_id": self.model_id, + "params": self.config.get("params"), + **self.config + } \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/models/novasonic.py b/src/strands/experimental/bidirectional_streaming/models/novasonic.py new file mode 100644 index 000000000..6fdd40c03 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/novasonic.py @@ -0,0 +1,716 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +Implements the BidirectionalModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol. + +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events +""" + +import asyncio +import base64 +import json +import logging +import time +import traceback +import uuid +from typing import AsyncIterable + +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import BidirectionalInputPayloadPart, InvokeModelWithBidirectionalStreamInputChunk, InvokeModelWithBidirectionalStreamOperationOutput +from smithy_aws_core.identity.environment import EnvironmentCredentialsResolver + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + InterruptionDetectedEvent, + TextOutputEvent, + UsageMetricsEvent, +) +from .base_model import BaseModel + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64", +} + +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 24000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH", 
+} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + +# Timing constants +SILENCE_THRESHOLD = 2.0 +EVENT_DELAY = 0.1 +RESPONSE_TIMEOUT = 1.0 + + +class NovaSonicSession(BidirectionalModelSession): + """Nova Sonic connection implementation handling the provider's specific protocol. + + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidirectionalModelSession + interface. + """ + + def __init__(self, stream: InvokeModelWithBidirectionalStreamOperationOutput, config: dict[str, any]) -> None: + """Initialize Nova Sonic connection. + + Args: + stream: Nova Sonic bidirectional stream operation output from AWS SDK. + config: Model configuration. + """ + self.stream = stream + self.config = config + self.prompt_name = str(uuid.uuid4()) + self._active = True + + # Nova Sonic requires unique content names + self.audio_content_name = str(uuid.uuid4()) + self.text_content_name = str(uuid.uuid4()) + + # Audio connection state + self.audio_connection_active = False + self.last_audio_time = None + self.silence_threshold = SILENCE_THRESHOLD + self.silence_task = None + + # Validate stream + if not stream: + logger.error("Stream is None") + raise ValueError("Stream cannot be None") + + logger.debug("Nova Sonic connection initialized with prompt: %s", self.prompt_name) + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize Nova Sonic connection with required protocol sequence.""" + try: + system_prompt = system_prompt or "You are a helpful assistant. Keep responses brief." + + init_events = self._build_initialization_events(system_prompt, tools or [], messages) + + logger.debug("Nova Sonic initialization - sending %d events", len(init_events)) + await self._send_initialization_events(init_events) + + logger.info("Nova Sonic connection initialized successfully") + self._response_task = asyncio.create_task(self._process_responses()) + + except Exception as e: + logger.error("Error during Nova Sonic initialization: %s", e) + raise + + def _build_initialization_events( + self, system_prompt: str, tools: list[ToolSpec], messages: Messages | None + ) -> list[str]: + """Build the sequence of initialization events.""" + events = [self._get_connection_start_event(), self._get_prompt_start_event(tools)] + + events.extend(self._get_system_prompt_events(system_prompt)) + + # Message history would be processed here if needed in the future + # Currently not implemented as it's not used in the existing test cases + + return events + + async def _send_initialization_events(self, events: list[str]) -> None: + """Send initialization events with required delays.""" + for _i, event in enumerate(events): + await self._send_nova_event(event) + await asyncio.sleep(EVENT_DELAY) + + async def _process_responses(self) -> None: + """Process Nova Sonic responses continuously.""" + logger.debug("Nova Sonic response processor started") + + try: + while self._active: + try: + output = await asyncio.wait_for(self.stream.await_output(), timeout=RESPONSE_TIMEOUT) + result = await output[1].receive() + + if result.value and result.value.bytes_: + await self._handle_response_data(result.value.bytes_.decode("utf-8")) + + except asyncio.TimeoutError: + await asyncio.sleep(0.1) + continue + except Exception as e: + logger.warning("Nova Sonic response error: %s", e) + await asyncio.sleep(0.1) 
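+ # Keep the response loop alive; transient errors are logged and retried.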
+ continue + + except Exception as e: + logger.error("Nova Sonic fatal error: %s", e) + finally: + logger.debug("Nova Sonic response processor stopped") + + async def _handle_response_data(self, response_data: str) -> None: + """Handle decoded response data from Nova Sonic.""" + try: + json_data = json.loads(response_data) + + if "event" in json_data: + nova_event = json_data["event"] + self._log_event_type(nova_event) + + if not hasattr(self, "_event_queue"): + self._event_queue = asyncio.Queue() + + await self._event_queue.put(nova_event) + except json.JSONDecodeError as e: + logger.warning("Nova Sonic JSON decode error: %s", e) + + def _log_event_type(self, nova_event: dict[str, any]) -> None: + """Log specific Nova Sonic event types for debugging.""" + if "usageEvent" in nova_event: + logger.debug("Nova usage: %s", nova_event["usageEvent"]) + elif "textOutput" in nova_event: + logger.debug("Nova text output") + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + logger.debug("Nova tool use: %s (id: %s)", tool_use["toolName"], tool_use["toolUseId"]) + elif "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + logger.debug("Nova audio output: %d bytes", len(audio_bytes)) + + async def receive_events(self) -> AsyncIterable[dict[str, any]]: + """Receive Nova Sonic events and convert to provider-agnostic format.""" + if not self.stream: + logger.error("Stream is None") + return + + logger.debug("Nova events - starting event stream") + + # Emit connection start event to Strands event system + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.prompt_name, + "metadata": {"provider": "nova_sonic", "model_id": self.config.get("model_id")}, + } + yield {"BidirectionalConnectionStart": connection_start} + + # Initialize event queue if not already done + if not hasattr(self, "_event_queue"): + self._event_queue = asyncio.Queue() + + try: + while self._active: + try: + # Get events from the queue populated by _process_responses + nova_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + + # Convert to provider-agnostic format + provider_event = self._convert_nova_event(nova_event) + if provider_event: + yield provider_event + + except asyncio.TimeoutError: + # No events in queue - continue waiting + continue + + except Exception as e: + logger.error("Error receiving Nova Sonic event: %s", e) + logger.error(traceback.format_exc()) + finally: + # Emit connection end event when exiting + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.prompt_name, + "reason": "connection_complete", + "metadata": {"provider": "nova_sonic"}, + } + yield {"BidirectionalConnectionEnd": connection_end} + + async def start_audio_connection(self) -> None: + """Start audio input connection (call once before sending audio chunks).""" + if self.audio_connection_active: + return + + logger.debug("Nova audio connection start") + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": self.audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": NOVA_AUDIO_INPUT_CONFIG, + } + } + } + ) + + await self._send_nova_event(audio_content_start) + self.audio_connection_active = True + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio using Nova Sonic protocol-specific format.""" + if not self._active: + return + + 
# Start audio connection if not already active
+ if not self.audio_connection_active:
+ await self.start_audio_connection()
+
+ # Update last audio time and cancel any pending silence task
+ self.last_audio_time = time.time()
+ if self.silence_task and not self.silence_task.done():
+ self.silence_task.cancel()
+
+ # Convert audio to Nova Sonic base64 format
+ nova_audio_data = base64.b64encode(audio_input["audioData"]).decode("utf-8")
+
+ # Send audio input event
+ audio_event = json.dumps(
+ {
+ "event": {
+ "audioInput": {
+ "promptName": self.prompt_name,
+ "contentName": self.audio_content_name,
+ "content": nova_audio_data,
+ }
+ }
+ }
+ )
+
+ await self._send_nova_event(audio_event)
+
+ # Start silence detection task
+ self.silence_task = asyncio.create_task(self._check_silence())
+
+ async def _check_silence(self) -> None:
+ """Check for silence and automatically end audio connection."""
+ try:
+ await asyncio.sleep(self.silence_threshold)
+ if self.audio_connection_active and self.last_audio_time:
+ elapsed = time.time() - self.last_audio_time
+ if elapsed >= self.silence_threshold:
+ logger.debug("Nova silence detected: %.2f seconds", elapsed)
+ await self.end_audio_input()
+ except asyncio.CancelledError:
+ pass
+
+ async def end_audio_input(self) -> None:
+ """End current audio input connection to trigger Nova Sonic processing."""
+ if not self.audio_connection_active:
+ return
+
+ logger.debug("Nova audio connection end")
+
+ audio_content_end = json.dumps(
+ {"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": self.audio_content_name}}}
+ )
+
+ await self._send_nova_event(audio_content_end)
+ self.audio_connection_active = False
+
+ async def send_text_content(self, text: str, **kwargs) -> None:
+ """Send text content using Nova Sonic format."""
+ if not self._active:
+ return
+
+ content_name = str(uuid.uuid4())
+ events = [
+ self._get_text_content_start_event(content_name),
+ self._get_text_input_event(content_name, text),
+ self._get_content_end_event(content_name),
+ ]
+
+ for event in events:
+ await self._send_nova_event(event)
+
+ async def send_interrupt(self) -> None:
+ """Send interruption signal to Nova Sonic."""
+ if not self._active:
+ return
+
+ # Nova Sonic handles interruption through special input events
+ interrupt_event = {
+ "event": {
+ "audioInput": {
+ "promptName": self.prompt_name,
+ "contentName": self.audio_content_name,
+ "stopReason": "INTERRUPTED",
+ }
+ }
+ }
+ # _send_nova_event expects a JSON string, so serialize the dict first
+ await self._send_nova_event(json.dumps(interrupt_event))
+
+ async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None:
+ """Send tool result using Nova Sonic toolResult format."""
+ if not self._active:
+ return
+
+ logger.debug("Nova tool result send: %s", tool_use_id)
+ content_name = str(uuid.uuid4())
+ events = [
+ self._get_tool_content_start_event(content_name, tool_use_id),
+ self._get_tool_result_event(content_name, result),
+ self._get_content_end_event(content_name),
+ ]
+
+ for _i, event in enumerate(events):
+ await self._send_nova_event(event)
+
+ async def close(self) -> None:
+ """Close Nova Sonic connection with proper cleanup sequence."""
+ if not self._active:
+ return
+
+ logger.debug("Nova cleanup - starting connection close")
+ self._active = False
+
+ # Cancel response processing task if running
+ if hasattr(self, "_response_task") and not self._response_task.done():
+ self._response_task.cancel()
+ try:
+ await self._response_task
+ except asyncio.CancelledError:
+ pass
+
+ try:
+ # End audio connection if active
+ if 
self.audio_connection_active: + await self.end_audio_input() + + # Send cleanup events + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + + for event in cleanup_events: + try: + await self._send_nova_event(event) + except Exception as e: + logger.warning("Error during Nova Sonic cleanup: %s", e) + + # Close stream + try: + await self.stream.input_stream.close() + except Exception as e: + logger.warning("Error closing Nova Sonic stream: %s", e) + + except Exception as e: + logger.error("Nova cleanup error: %s", str(e)) + finally: + logger.debug("Nova connection closed") + + def _convert_nova_event(self, nova_event: dict[str, any]) -> dict[str, any] | None: + """Convert Nova Sonic events to provider-agnostic format.""" + # Handle audio output + if "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + + audio_output: AudioOutputEvent = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": "base64", + } + + return {"audioOutput": audio_output} + + # Handle text output + elif "textOutput" in nova_event: + text_content = nova_event["textOutput"]["content"] + # Use stored role from contentStart event, fallback to event role + role = getattr(self, "_current_role", nova_event["textOutput"].get("role", "assistant")) + + # Check for Nova Sonic interruption pattern (matches working sample) + if '{ "interrupted" : true }' in text_content: + logger.debug("Nova interruption detected in text") + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + return {"interruptionDetected": interruption} + + # Show transcription for user speech - ALWAYS show these regardless of DEBUG flag + if role == "USER": + print(f"User: {text_content}") + elif role == "ASSISTANT": + print(f"Assistant: {text_content}") + + text_output: TextOutputEvent = {"text": text_content, "role": role.lower()} + + return {"textOutput": text_output} + + # Handle tool use + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]), + } + + return {"toolUse": tool_use_event} + + # Handle interruption + elif nova_event.get("stopReason") == "INTERRUPTED": + logger.debug("Nova interruption stop reason") + + interruption: InterruptionDetectedEvent = {"reason": "user_input"} + + return {"interruptionDetected": interruption} + + # Handle usage events - convert to standardized format + elif "usageEvent" in nova_event: + usage_data = nova_event["usageEvent"] + usage_metrics: UsageMetricsEvent = { + "totalTokens": usage_data.get("totalTokens", 0), + "inputTokens": usage_data.get("totalInputTokens", 0), + "outputTokens": usage_data.get("totalOutputTokens", 0), + "audioTokens": usage_data.get("details", {}).get("total", {}).get("output", {}).get("speechTokens", 0) + } + return {"usageMetrics": usage_metrics} + + # Handle content start events (track role) + elif "contentStart" in nova_event: + role = nova_event["contentStart"].get("role", "unknown") + # Store role for subsequent text output events + self._current_role = role + return None + + # Handle other events + else: + return None + + # Nova Sonic event template methods + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + + def 
_get_prompt_start_event(self, tools: list[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + prompt_start_event = { + "event": { + "promptStart": { + "promptName": self.prompt_name, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": NOVA_AUDIO_OUTPUT_CONFIG, + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict]: + """Build tool configuration from tool specs.""" + tool_config = [] + for tool in tools: + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], "description": tool["description"], "inputSchema": input_schema}} + ) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str) -> list[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt), + self._get_content_end_event(content_name), + ] + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } + } + } + ) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> str: + """Generate tool content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self.prompt_name, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, + } + } + } + ) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps( + {"event": {"textInput": {"promptName": self.prompt_name, "contentName": content_name, "content": text}}} + ) + + def _get_tool_result_event(self, content_name: str, result: dict[str, any]) -> str: + """Generate tool result event.""" + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self.prompt_name, + "contentName": content_name, + "content": json.dumps(result), + } + } + } + ) + + def _get_content_end_event(self, content_name: str) -> str: + """Generate content end event.""" + return json.dumps({"event": {"contentEnd": {"promptName": self.prompt_name, "contentName": content_name}}}) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({"event": {"promptEnd": {"promptName": self.prompt_name}}}) + + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" + return json.dumps({"event": {"connectionEnd": {}}}) + + async def _send_nova_event(self, event: str) -> None: + """Send event JSON string to Nova Sonic stream.""" + try: + # Event is already a JSON string + bytes_data = event.encode("utf-8") + chunk = 
InvokeModelWithBidirectionalStreamInputChunk(value=BidirectionalInputPayloadPart(bytes_=bytes_data)) + await self.stream.input_stream.send(chunk) + logger.debug("Successfully sent Nova Sonic event") + + except Exception as e: + logger.error("Error sending Nova Sonic event: %s", e) + logger.error("Event was: %s", event) + raise + + +class NovaSonicBidirectionalModel(BidirectionalModel): + """Nova Sonic model implementation for bidirectional streaming. + + Provides access to Amazon's Nova Sonic model through the bidirectional + streaming interface, handling AWS authentication and connection management. + """ + + def __init__(self, model_id: str = "amazon.nova-sonic-v1:0", region: str = "us-east-1", **config: any) -> None: + """Initialize Nova Sonic bidirectional model. + + Args: + model_id: Nova Sonic model identifier. + region: AWS region. + **config: Additional configuration. + """ + self.model_id = model_id + self.region = region + self.config = config + self._client = None + + logger.debug("Nova Sonic bidirectional model initialized: %s", model_id) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create Nova Sonic bidirectional connection.""" + logger.debug("Nova connection create - starting") + + # Initialize client if needed + if not self._client: + await self._initialize_client() + + # Start Nova Sonic bidirectional stream + try: + stream = await self._client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + + # Create and initialize connection + connection = NovaSonicSession(stream, self.config) + await connection.initialize(system_prompt, tools, messages) + + logger.debug("Nova connection created") + return connection + except Exception as e: + logger.error("Nova connection create error: %s", str(e)) + logger.error("Failed to create Nova Sonic connection: %s", e) + raise + + async def _initialize_client(self) -> None: + """Initialize Nova Sonic client.""" + try: + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={"aws.auth#sigv4": SigV4AuthScheme(service="bedrock")}, + ) + + self._client = BedrockRuntimeClient(config=config) + logger.debug("Nova Sonic client initialized") + + except ImportError as e: + logger.error("Nova Sonic dependencies not available: %s", e) + raise + except Exception as e: + logger.error("Error initializing Nova Sonic client: %s", e) + raise diff --git a/src/strands/experimental/bidirectional_streaming/models/openai.py b/src/strands/experimental/bidirectional_streaming/models/openai.py new file mode 100644 index 000000000..dd436d60e --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/models/openai.py @@ -0,0 +1,502 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. 
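+
+Usage sketch (illustrative only, run inside an async function; not part of this
+module's API surface). It assumes OPENAI_API_KEY is set and that pcm_bytes holds
+raw 24 kHz, 16-bit mono PCM captured elsewhere; the calls mirror how the
+accompanying test scripts drive the agent:
+
+    model = OpenAIRealtimeBidirectionalModel(model="gpt-realtime")
+    agent = BidirectionalAgent(model=model, tools=[], system_prompt="Keep replies brief.")
+    await agent.start()
+    await agent.send({"audioData": pcm_bytes, "format": "pcm", "sampleRate": 24000, "channels": 1})
+    async for event in agent.receive():
+        if "textOutput" in event:
+            print(event["textOutput"]["text"])
+    await agent.end()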
+""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import AsyncIterable + +import websockets +from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosed + +from ....types.content import Messages +from ....types.tools import ToolSpec, ToolUse +from ..types.bidirectional_streaming import ( + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + TextOutputEvent, + VoiceActivityEvent, +) +from .base_model import BaseModel + +logger = logging.getLogger(__name__) + +# OpenAI Realtime API configuration +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" + +AUDIO_FORMAT = {"type": "audio/pcm", "rate": 24000} + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": AUDIO_FORMAT, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + } + }, + "output": {"format": AUDIO_FORMAT, "voice": "alloy"}, + }, +} + + +class OpenAIRealtimeSession(BidirectionalModelSession): + """OpenAI Realtime API session for real-time audio/text streaming. + + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + def __init__(self, websocket: WebSocketClientProtocol, config: dict[str, any]) -> None: + """Initialize OpenAI Realtime session.""" + self.websocket = websocket + self.config = config + self.session_id = str(uuid.uuid4()) + self._active = True + + self._event_queue = asyncio.Queue() + self._response_task = None + self._function_call_buffer = {} + + logger.debug("OpenAI Realtime session initialized: %s", self.session_id) + + def _require_active(self) -> bool: + """Check if session is active.""" + return self._active + + def _create_text_event(self, text: str, role: str) -> dict[str, any]: + """Create standardized text output event.""" + text_output: TextOutputEvent = {"text": text, "role": role} + return {"textOutput": text_output} + + def _create_voice_activity_event(self, activity_type: str) -> dict[str, any]: + """Create standardized voice activity event.""" + voice_activity: VoiceActivityEvent = {"activityType": activity_type} + return {"voiceActivity": voice_activity} + + + + async def initialize( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + ) -> None: + """Initialize session with configuration.""" + try: + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + if messages: + await self._add_conversation_history(messages) + + self._response_task = asyncio.create_task(self._process_responses()) + logger.info("OpenAI Realtime session initialized successfully") + + except Exception as e: + logger.error("Error during OpenAI Realtime initialization: %s", e) + raise + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict: + """Build session configuration for OpenAI Realtime API.""" + config = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) 
+ + custom_config = self.config.get("session", {}) + supported_params = { + "type", "output_modalities", "instructions", "voice", "audio", + "tools", "tool_choice", "input_audio_format", "output_audio_format", + "input_audio_transcription", "turn_detection" + } + + for key, value in custom_config.items(): + if key in supported_params: + config[key] = value + else: + logger.warning("Ignoring unsupported session parameter: %s", key) + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session.""" + for message in messages: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": message["role"], "content": []} + } + + content = message.get("content", "") + if isinstance(content, str): + conversation_item["item"]["content"].append({"type": "input_text", "text": content}) + elif isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + conversation_item["item"]["content"].append({"type": "input_text", "text": item.get("text", "")}) + + await self._send_event(conversation_item) + + async def _process_responses(self) -> None: + """Process incoming WebSocket messages.""" + logger.debug("OpenAI Realtime response processor started") + + try: + async for message in self.websocket: + if not self._active: + break + + try: + event = json.loads(message) + await self._event_queue.put(event) + except json.JSONDecodeError as e: + logger.warning("Failed to parse OpenAI event: %s", e) + continue + + except ConnectionClosed: + logger.debug("OpenAI Realtime WebSocket connection closed") + except Exception as e: + logger.error("Error in OpenAI Realtime response processing: %s", e) + finally: + self._active = False + logger.debug("OpenAI Realtime response processor stopped") + + async def receive_events(self) -> AsyncIterable[BidirectionalStreamEvent]: + """Receive OpenAI events and convert to Strands format.""" + connection_start: BidirectionalConnectionStartEvent = { + "connectionId": self.session_id, + "metadata": {"provider": "openai_realtime", "model": self.config.get("model", DEFAULT_MODEL)}, + } + yield {"BidirectionalConnectionStart": connection_start} + + try: + while self._active: + try: + openai_event = await asyncio.wait_for(self._event_queue.get(), timeout=1.0) + provider_event = self._convert_openai_event(openai_event) + if provider_event: + yield provider_event + except asyncio.TimeoutError: + continue + + except Exception as e: + logger.error("Error receiving OpenAI Realtime event: %s", e) + finally: + connection_end: BidirectionalConnectionEndEvent = { + "connectionId": self.session_id, + "reason": "connection_complete", + "metadata": {"provider": "openai_realtime"}, + } + yield {"BidirectionalConnectionEnd": connection_end} + + def _convert_openai_event(self, 
openai_event: dict[str, any]) -> dict[str, any] | None: + """Convert OpenAI events to Strands format.""" + event_type = openai_event.get("type") + + # Audio output + if event_type == "response.output_audio.delta": + audio_data = base64.b64decode(openai_event["delta"]) + audio_output: AudioOutputEvent = { + "audioData": audio_data, + "format": "pcm", + "sampleRate": 24000, + "channels": 1, + "encoding": None, + } + return {"audioOutput": audio_output} + + # Assistant text output events - combine multiple similar events + elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: + return self._create_text_event(openai_event["delta"], "assistant") + + # User transcription events - combine multiple similar events + elif event_type in ["conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed"]: + text_key = "delta" if "delta" in event_type else "transcript" + text = openai_event.get(text_key, "") + return self._create_text_event(text, "user") if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + return self._create_text_event(text, "user") if text.strip() else None + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("OpenAI transcription failed: %s", error_info.get("message", "Unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + return {"toolUse": tool_use} + except (json.JSONDecodeError, KeyError) as e: + logger.warning("Error parsing function arguments for %s: %s", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection events - combine similar events using mapping + elif event_type in ["input_audio_buffer.speech_started", "input_audio_buffer.speech_stopped", + "input_audio_buffer.timeout_triggered"]: + # Map event types to activity types + activity_map = { + "input_audio_buffer.speech_started": "speech_started", + "input_audio_buffer.speech_stopped": "speech_stopped", + "input_audio_buffer.timeout_triggered": "timeout" + } + return self._create_voice_activity_event(activity_map[event_type]) + + # Lifecycle events (log only) - combine multiple similar events + elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: + item = openai_event.get("item", {}) + action = "retrieved" if "retrieve" in event_type else "added" + logger.debug("OpenAI conversation item %s: %s", action, item.get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("OpenAI conversation item done: %s", 
openai_event.get("item", {}).get("id")) + + item = openai_event.get("item", {}) + if item.get("type") == "message" and item.get("role") == "assistant": + content_parts = item.get("content", []) + if content_parts: + message_content = [] + for content_part in content_parts: + if content_part.get("type") == "output_text": + message_content.append({"type": "text", "text": content_part.get("text", "")}) + elif content_part.get("type") == "output_audio": + transcript = content_part.get("transcript", "") + if transcript: + message_content.append({"type": "text", "text": transcript}) + + if message_content: + message = {"role": "assistant", "content": message_content} + return {"messageStop": {"message": message}} + return None + + # Response output events - combine similar events + elif event_type in ["response.output_item.added", "response.output_item.done", + "response.content_part.added", "response.content_part.done"]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug("OpenAI %s: %s", event_type, item_data.get("id") if item_data else "unknown") + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": function_name, "arguments": ""} + else: + self._function_call_buffer[call_id]["name"] = function_name + return None + + # Session/buffer events - combine simple log-only events + elif event_type in ["input_audio_buffer.committed", "input_audio_buffer.cleared", + "session.created", "session.updated"]: + logger.debug("OpenAI %s event", event_type) + return None + + elif event_type == "error": + logger.error("OpenAI Realtime error: %s", openai_event.get("error", {})) + return None + + else: + logger.debug("Unhandled OpenAI event type: %s", event_type) + return None + + async def send_audio_content(self, audio_input: AudioInputEvent) -> None: + """Send audio content to OpenAI for processing.""" + if not self._require_active(): + return + + audio_base64 = base64.b64encode(audio_input["audioData"]).decode("utf-8") + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_base64}) + + async def send_text_content(self, text: str) -> None: + """Send text content to OpenAI for processing.""" + if not self._require_active(): + return + + item_data = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": text}] + } + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def send_interrupt(self) -> None: + """Send interruption signal to OpenAI.""" + if not self._require_active(): + return + + await self._send_event({"type": "response.cancel"}) + + async def send_tool_result(self, tool_use_id: str, result: dict[str, any]) -> None: + """Send tool result back to OpenAI.""" + if not self._require_active(): + return + + logger.debug("OpenAI tool result send: %s", tool_use_id) + result_text = json.dumps(result) if not isinstance(result, str) else result + + item_data = { + "type": "function_call_output", + "call_id": tool_use_id, + "output": result_text + } + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async 
def close(self) -> None: + """Close session and cleanup resources.""" + if not self._active: + return + + logger.debug("OpenAI Realtime cleanup - starting connection close") + self._active = False + + if self._response_task and not self._response_task.done(): + self._response_task.cancel() + try: + await self._response_task + except asyncio.CancelledError: + pass + + try: + await self.websocket.close() + except Exception as e: + logger.warning("Error closing OpenAI Realtime WebSocket: %s", e) + + logger.debug("OpenAI Realtime connection closed") + + async def _send_event(self, event: dict[str, any]) -> None: + """Send event to OpenAI via WebSocket.""" + try: + message = json.dumps(event) + await self.websocket.send(message) + logger.debug("Sent OpenAI event: %s", event.get("type")) + except Exception as e: + logger.error("Error sending OpenAI event: %s", e) + raise + + +class OpenAIRealtimeBidirectionalModel(BidirectionalModel): + """OpenAI Realtime API provider for Strands bidirectional streaming. + + Provides real-time audio/text communication through OpenAI's Realtime API + with WebSocket connections, voice activity detection, and function calling. + """ + + def __init__( + self, + model: str = DEFAULT_MODEL, + api_key: str | None = None, + **config: any + ) -> None: + """Initialize OpenAI Realtime bidirectional model.""" + self.model = model + self.api_key = api_key + self.config = config + + import os + if not self.api_key: + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass api_key parameter.") + + logger.debug("OpenAI Realtime bidirectional model initialized: %s", model) + + async def create_bidirectional_connection( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs, + ) -> BidirectionalModelSession: + """Create bidirectional connection to OpenAI Realtime API.""" + logger.info("Creating OpenAI Realtime connection...") + + try: + url = f"{OPENAI_REALTIME_URL}?model={self.model}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if "organization" in self.config: + headers.append(("OpenAI-Organization", self.config["organization"])) + if "project" in self.config: + headers.append(("OpenAI-Project", self.config["project"])) + + websocket = await websockets.connect(url, additional_headers=headers) + logger.info("WebSocket connected successfully") + + session = OpenAIRealtimeSession(websocket, self.config) + await session.initialize(system_prompt, tools, messages) + + logger.info("OpenAI Realtime connection established") + return session + + except Exception as e: + logger.error("OpenAI connection error: %s", e) + raise \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py new file mode 100644 index 000000000..8c3ae3b4c --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_novasonic.py @@ -0,0 +1,225 @@ +"""Test suite for bidirectional streaming with real-time audio interaction. + +Tests the complete bidirectional streaming system including audio input/output, +interruption handling, and concurrent tool execution using Nova Sonic. 
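+
+Requirements (assumed; mirrors the other streaming tests): pyaudio, strands_tools,
+a working microphone and speakers, and AWS credentials exposed via
+AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY.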
+""" + +import asyncio +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import os +import time + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.novasonic import NovaSonicBidirectionalModel + + +def test_direct_tools(): + """Test direct tool calling.""" + print("Testing direct tool calling...") + + # Check AWS credentials + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY")]): + print("AWS credentials not set - skipping test") + return + + try: + model = NovaSonicBidirectionalModel() + agent = BidirectionalAgent(model=model, tools=[calculator]) + + # Test calculator + result = agent.tool.calculator(expression="2 * 3") + content = result.get("content", [{}])[0].get("text", "") + print(f"Result: {content}") + print("Test completed") + + except Exception as e: + print(f"Test failed: {e}") + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + # Handle audio output + if "audioOutput" in event: + if not context.get("interrupted", False): + context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + + # Handle interruption events + elif "interruptionDetected" in event: + context["interrupted"] = True + elif "interrupted" in event: + context["interrupted"] = True + + # Handle text output with interruption detection + elif "textOutput" in event: + text_content = event["textOutput"].get("content", "") + 
role = event["textOutput"].get("role", "unknown") + + # Check for text-based interruption patterns + if '{ "interrupted" : true }' in text_content: + context["interrupted"] = True + elif "interrupted" in text_content.lower(): + context["interrupted"] = True + + # Log text output + if role.upper() == "USER": + print(f"User: {text_content}") + elif role.upper() == "ASSISTANT": + print(f"Assistant: {text_content}") + + except asyncio.CancelledError: + pass + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + await agent.send(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) # Restored to working timing + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for bidirectional streaming test.""" + print("Starting bidirectional streaming test...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + + # Initialize model and agent + model = NovaSonicBidirectionalModel(region="us-east-1") + agent = BidirectionalAgent(model=model, tools=[calculator], system_prompt="You are a helpful assistant.") + + await agent.start() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "connection": agent._session, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + } + + print("Speak into microphone. Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), record(context), receive(agent, context), send(agent, context), return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.end() + + +if __name__ == "__main__": + # Test direct tool calling first + test_direct_tools() + + asyncio.run(main()) diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py new file mode 100644 index 000000000..660040f3e --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_bidi_openai.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +"""Test OpenAI Realtime API speech-to-speech interaction.""" + +import asyncio +import os +import sys +import time +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent / "src")) + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.openai import OpenAIRealtimeBidirectionalModel + + +async def play(context): + """Handle audio playback with interruption support.""" + audio = pyaudio.PyAudio() + + try: + speaker = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # OpenAI Realtime uses 24kHz + output=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + # Check for interruption + if context.get("interrupted", False): + # Clear audio queue on interruption + while not context["audio_out"].empty(): 
+ try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get audio data with timeout + try: + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + # Play in chunks to allow interruption + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + if context.get("interrupted", False) or not context["active"]: + break + + chunk = audio_data[i:i + chunk_size] + speaker.write(chunk) + await asyncio.sleep(0.001) # Brief pause for responsiveness + + except asyncio.TimeoutError: + continue + + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Audio playback error: {e}") + finally: + try: + speaker.close() + except: + pass + audio.terminate() + + +async def record(context): + """Handle microphone recording.""" + audio = pyaudio.PyAudio() + + try: + microphone = audio.open( + format=pyaudio.paInt16, + channels=1, + rate=24000, # Match OpenAI's expected input rate + input=True, + frames_per_buffer=1024, + ) + + while context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + await context["audio_in"].put(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Microphone recording error: {e}") + finally: + try: + microphone.close() + except: + pass + audio.terminate() + + +async def receive(agent, context): + """Handle events from the agent.""" + try: + async for event in agent.receive(): + if not context["active"]: + break + + # Handle audio output + if "audioOutput" in event: + audio_data = event["audioOutput"]["audioData"] + + if not context.get("interrupted", False): + await context["audio_out"].put(audio_data) + + # Handle text output (transcripts) + elif "textOutput" in event: + text_output = event["textOutput"] + role = text_output.get("role", "assistant") + text = text_output.get("text", "").strip() + + if text: + if role == "user": + print(f"User: {text}") + elif role == "assistant": + print(f"Assistant: {text}") + + # Handle interruption detection + elif "interruptionDetected" in event: + context["interrupted"] = True + + # Handle connection events + elif "BidirectionalConnectionStart" in event: + pass # Silent connection start + elif "BidirectionalConnectionEnd" in event: + context["active"] = False + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Receive handler error: {e}") + finally: + pass + + +async def send(agent, context): + """Send audio from microphone to agent.""" + try: + while context["active"]: + try: + audio_bytes = await asyncio.wait_for(context["audio_in"].get(), timeout=0.1) + + # Create audio event in expected format + audio_event = { + "audioData": audio_bytes, + "format": "pcm", + "sampleRate": 24000, + "channels": 1 + } + + await agent.send(audio_event) + + except asyncio.TimeoutError: + continue + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + except Exception as e: + print(f"Send handler error: {e}") + finally: + pass + + +async def main(): + """Main test function for OpenAI voice chat.""" + print("Starting OpenAI Realtime API test...") + + # Check API key + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("OPENAI_API_KEY environment variable not set") + return False + + # Check audio system + try: + 
audio = pyaudio.PyAudio() + audio.terminate() + except Exception as e: + print(f"Audio system error: {e}") + return False + + # Create OpenAI model + model = OpenAIRealtimeBidirectionalModel( + model="gpt-4o-realtime-preview", + api_key=api_key, + session={ + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": 24000}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "silence_duration_ms": 700 + } + }, + "output": { + "format": {"type": "audio/pcm", "rate": 24000}, + "voice": "alloy" + } + } + } + ) + + # Create agent + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful voice assistant. Keep your responses brief and natural. Say hello when you first connect." + ) + + # Start the session + await agent.start() + + # Create shared context + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "interrupted": False, + "start_time": time.time() + } + + print("Speak into your microphone. Press Ctrl+C to stop.") + + try: + # Run all tasks concurrently + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + return_exceptions=True + ) + + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + except Exception as e: + print(f"\nError during voice chat: {e}") + finally: + print("Cleaning up...") + context["active"] = False + + try: + await agent.end() + except Exception as e: + print(f"Cleanup error: {e}") + + return True + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nTest interrupted by user") + except Exception as e: + print(f"Test error: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py new file mode 100644 index 000000000..4469e819a --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/tests/test_gemini_live.py @@ -0,0 +1,359 @@ +"""Test suite for Gemini Live bidirectional streaming with camera support. 
+ +Tests the Gemini Live API with real-time audio and video interaction including: +- Audio input/output streaming +- Camera frame capture and transmission +- Interruption handling +- Concurrent tool execution +- Transcript events + +Requirements: +- pip install opencv-python pillow pyaudio google-genai +- Camera access permissions +- GOOGLE_AI_API_KEY environment variable +""" + +import asyncio +import base64 +import io +import logging +import os +import sys +from pathlib import Path + +# Add the src directory to Python path +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent.parent)) +import time + +try: + import cv2 + import PIL.Image + CAMERA_AVAILABLE = True +except ImportError as e: + print(f"Camera dependencies not available: {e}") + print("Install with: pip install opencv-python pillow") + CAMERA_AVAILABLE = False + +import pyaudio +from strands_tools import calculator + +from strands.experimental.bidirectional_streaming.agent.agent import BidirectionalAgent +from strands.experimental.bidirectional_streaming.models.gemini_live import GeminiLiveBidirectionalModel + +# Configure logging - debug only for Gemini Live, info for everything else +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +gemini_logger = logging.getLogger('strands.experimental.bidirectional_streaming.models.gemini_live') +gemini_logger.setLevel(logging.DEBUG) +logger = logging.getLogger(__name__) + + +async def play(context): + """Play audio output with responsive interruption support.""" + audio = pyaudio.PyAudio() + speaker = audio.open( + channels=1, + format=pyaudio.paInt16, + output=True, + rate=24000, + frames_per_buffer=1024, + ) + + try: + while context["active"]: + try: + # Check for interruption first + if context.get("interrupted", False): + # Clear entire audio queue immediately + while not context["audio_out"].empty(): + try: + context["audio_out"].get_nowait() + except asyncio.QueueEmpty: + break + + context["interrupted"] = False + await asyncio.sleep(0.05) + continue + + # Get next audio data + audio_data = await asyncio.wait_for(context["audio_out"].get(), timeout=0.1) + + if audio_data and context["active"]: + chunk_size = 1024 + for i in range(0, len(audio_data), chunk_size): + # Check for interruption before each chunk + if context.get("interrupted", False) or not context["active"]: + break + + end = min(i + chunk_size, len(audio_data)) + chunk = audio_data[i:end] + speaker.write(chunk) + await asyncio.sleep(0.001) + + except asyncio.TimeoutError: + continue # No audio available + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + except asyncio.CancelledError: + pass + finally: + speaker.close() + audio.terminate() + + +async def record(context): + """Record audio input from microphone.""" + audio = pyaudio.PyAudio() + + # List all available audio devices + print("Available audio devices:") + for i in range(audio.get_device_count()): + device_info = audio.get_device_info_by_index(i) + if device_info['maxInputChannels'] > 0: # Only show input devices + print(f" Device {i}: {device_info['name']} (inputs: {device_info['maxInputChannels']})") + + # Get default input device info + default_device = audio.get_default_input_device_info() + print(f"\nUsing default input device: {default_device['name']} (Device {default_device['index']})") + + microphone = audio.open( + channels=1, + format=pyaudio.paInt16, + frames_per_buffer=1024, + input=True, + rate=16000, + ) + + try: + while 
context["active"]: + try: + audio_bytes = microphone.read(1024, exception_on_overflow=False) + context["audio_in"].put_nowait(audio_bytes) + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + except asyncio.CancelledError: + pass + finally: + microphone.close() + audio.terminate() + + +async def receive(agent, context): + """Receive and process events from agent.""" + try: + async for event in agent.receive(): + # Debug: Log all event types + event_types = [k for k in event.keys() if not k.startswith('_')] + if event_types: + logger.debug(f"Received event types: {event_types}") + + # Handle audio output + if "audioOutput" in event: + if not context.get("interrupted", False): + context["audio_out"].put_nowait(event["audioOutput"]["audioData"]) + + # Handle interruption events + elif "interruptionDetected" in event: + context["interrupted"] = True + elif "interrupted" in event: + context["interrupted"] = True + + # Handle text output + elif "textOutput" in event: + text_content = event["textOutput"].get("text", "") + role = event["textOutput"].get("role", "unknown") + + # Check for text-based interruption patterns + if '{ "interrupted" : true }' in text_content: + context["interrupted"] = True + elif "interrupted" in text_content.lower(): + context["interrupted"] = True + + # Log text output + if role.upper() == "USER": + print(f"User: {text_content}") + elif role.upper() == "ASSISTANT": + print(f"Assistant: {text_content}") + + # Handle transcript events (audio transcriptions) + elif "transcript" in event: + transcript_text = event["transcript"].get("text", "") + transcript_role = event["transcript"].get("role", "unknown") + transcript_type = event["transcript"].get("type", "unknown") + + # Print transcripts with special formatting to distinguish from text output + if transcript_role.upper() == "USER": + print(f"🎤 User (transcript): {transcript_text}") + elif transcript_role.upper() == "ASSISTANT": + print(f"🔊 Assistant (transcript): {transcript_text}") + + # Handle turn complete events + elif "turnComplete" in event: + logger.debug("Turn complete event received - model ready for next input") + # Reset interrupted state since the turn is complete + context["interrupted"] = False + + except asyncio.CancelledError: + pass + + +def _get_frame(cap): + """Capture and process a frame from camera.""" + if not CAMERA_AVAILABLE: + return None + + # Read the frame + ret, frame = cap.read() + # Check if the frame was read successfully + if not ret: + return None + # Convert BGR to RGB color space + # OpenCV captures in BGR but PIL expects RGB format + # This prevents the blue tint in the video feed + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = PIL.Image.fromarray(frame_rgb) + img.thumbnail([1024, 1024]) + + image_io = io.BytesIO() + img.save(image_io, format="jpeg") + image_io.seek(0) + + mime_type = "image/jpeg" + image_bytes = image_io.read() + return {"mime_type": mime_type, "data": base64.b64encode(image_bytes).decode()} + + +async def get_frames(context): + """Capture frames from camera and send to agent.""" + if not CAMERA_AVAILABLE: + print("Camera not available - skipping video capture") + return + + # This takes about a second, and will block the whole program + # causing the audio pipeline to overflow if you don't to_thread it. + cap = await asyncio.to_thread(cv2.VideoCapture, 0) # 0 represents the default camera + + print("Camera initialized. 
Starting video capture...") + + try: + while context["active"] and time.time() - context["start_time"] < context["duration"]: + frame = await asyncio.to_thread(_get_frame, cap) + if frame is None: + break + + # Send frame to agent as image input + try: + image_event = { + "imageData": frame["data"], + "mimeType": frame["mime_type"], + "encoding": "base64" + } + await context["agent"].send(image_event) + print("📸 Frame sent to model") + except Exception as e: + logger.error(f"Error sending frame: {e}") + + # Wait 1 second between frames (1 FPS) + await asyncio.sleep(1.0) + + except asyncio.CancelledError: + pass + finally: + # Release the VideoCapture object + cap.release() + + +async def send(agent, context): + """Send audio input to agent.""" + try: + while time.time() - context["start_time"] < context["duration"]: + try: + audio_bytes = context["audio_in"].get_nowait() + audio_event = {"audioData": audio_bytes, "format": "pcm", "sampleRate": 16000, "channels": 1} + await agent.send(audio_event) + except asyncio.QueueEmpty: + await asyncio.sleep(0.01) + except asyncio.CancelledError: + break + + context["active"] = False + except asyncio.CancelledError: + pass + + +async def main(duration=180): + """Main function for Gemini Live bidirectional streaming test with camera support.""" + print("Starting Gemini Live bidirectional streaming test with camera...") + print("Audio optimizations: 1024-byte buffers, balanced smooth playback + responsive interruption") + print("Video: Camera frames sent at 1 FPS to model") + + # Get API key from environment variable + api_key = os.getenv("GOOGLE_AI_API_KEY") + + if not api_key: + print("ERROR: GOOGLE_AI_API_KEY environment variable not set") + print("Please set it with: export GOOGLE_AI_API_KEY=your_api_key") + return + + # Initialize Gemini Live model with proper configuration + logger.info("Initializing Gemini Live model with API key") + + model = GeminiLiveBidirectionalModel( + model_id="gemini-2.5-flash-native-audio-preview-09-2025", + api_key=api_key, + params={ + "response_modalities": ["AUDIO"], + "output_audio_transcription": {}, # Enable output transcription + "input_audio_transcription": {} # Enable input transcription + } + ) + logger.info("Gemini Live model initialized successfully") + print("Using Gemini Live model") + + agent = BidirectionalAgent( + model=model, + tools=[calculator], + system_prompt="You are a helpful assistant." + ) + + await agent.start() + + # Create shared context for all tasks + context = { + "active": True, + "audio_in": asyncio.Queue(), + "audio_out": asyncio.Queue(), + "connection": agent._session, + "duration": duration, + "start_time": time.time(), + "interrupted": False, + "agent": agent, # Add agent reference for camera task + } + + print("Speak into microphone and show things to camera. 
Press Ctrl+C to exit.") + + try: + # Run all tasks concurrently including camera + await asyncio.gather( + play(context), + record(context), + receive(agent, context), + send(agent, context), + get_frames(context), # Add camera task + return_exceptions=True + ) + except KeyboardInterrupt: + print("\nInterrupted by user") + except asyncio.CancelledError: + print("\nTest cancelled") + finally: + print("Cleaning up...") + context["active"] = False + await agent.end() + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/src/strands/experimental/bidirectional_streaming/types/__init__.py b/src/strands/experimental/bidirectional_streaming/types/__init__.py new file mode 100644 index 000000000..d040ee436 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/__init__.py @@ -0,0 +1,39 @@ +"""Type definitions for bidirectional streaming.""" + +from .bidirectional_streaming import ( + DEFAULT_CHANNELS, + DEFAULT_SAMPLE_RATE, + SUPPORTED_AUDIO_FORMATS, + SUPPORTED_CHANNELS, + SUPPORTED_SAMPLE_RATES, + AudioInputEvent, + AudioOutputEvent, + BidirectionalConnectionEndEvent, + BidirectionalConnectionStartEvent, + BidirectionalStreamEvent, + ImageInputEvent, + InterruptionDetectedEvent, + TextOutputEvent, + TranscriptEvent, + UsageMetricsEvent, + VoiceActivityEvent, +) + +__all__ = [ + "AudioInputEvent", + "AudioOutputEvent", + "BidirectionalConnectionEndEvent", + "BidirectionalConnectionStartEvent", + "BidirectionalStreamEvent", + "ImageInputEvent", + "InterruptionDetectedEvent", + "TextOutputEvent", + "TranscriptEvent", + "UsageMetricsEvent", + "VoiceActivityEvent", + "SUPPORTED_AUDIO_FORMATS", + "SUPPORTED_SAMPLE_RATES", + "SUPPORTED_CHANNELS", + "DEFAULT_SAMPLE_RATE", + "DEFAULT_CHANNELS", +] diff --git a/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py new file mode 100644 index 000000000..73f86a469 --- /dev/null +++ b/src/strands/experimental/bidirectional_streaming/types/bidirectional_streaming.py @@ -0,0 +1,246 @@ +"""Bidirectional streaming types for real-time audio/text conversations. + +Type definitions for bidirectional streaming that extends Strands' existing streaming +capabilities with real-time audio and persistent connection support. + +Key features: +- Audio input/output events with standardized formats +- Interruption detection and handling +- connection lifecycle management +- Provider-agnostic event types +- Backwards compatibility with existing StreamEvent types + +Audio format normalization: +- Supports PCM, WAV, Opus, and MP3 formats +- Standardizes sample rates (16kHz, 24kHz, 48kHz) +- Normalizes channel configurations (mono/stereo) +- Abstracts provider-specific encodings +""" + +from typing import Any, Dict, Literal, Optional + +from typing_extensions import TypedDict + +from ....types.content import Role +from ....types.streaming import StreamEvent + +# Audio format constants +SUPPORTED_AUDIO_FORMATS = ["pcm", "wav", "opus", "mp3"] +SUPPORTED_SAMPLE_RATES = [16000, 24000, 48000] +SUPPORTED_CHANNELS = [1, 2] # 1=mono, 2=stereo +DEFAULT_SAMPLE_RATE = 16000 +DEFAULT_CHANNELS = 1 + + +class AudioOutputEvent(TypedDict): + """Audio output event from the model. + + Provides standardized audio output format across different providers using + raw bytes instead of provider-specific encodings. + + Attributes: + audioData: Raw audio bytes (not base64 or hex encoded). 
+        format: Audio format from SUPPORTED_AUDIO_FORMATS.
+        sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES.
+        channels: Channel count from SUPPORTED_CHANNELS.
+        encoding: Original provider encoding for debugging purposes.
+    """
+
+    audioData: bytes
+    format: Literal["pcm", "wav", "opus", "mp3"]
+    sampleRate: Literal[16000, 24000, 48000]
+    channels: Literal[1, 2]
+    encoding: Optional[str]
+
+
+class AudioInputEvent(TypedDict):
+    """Audio input event for sending audio to the model.
+
+    Used for sending audio data through the send() method.
+
+    Attributes:
+        audioData: Raw audio bytes to send to the model.
+        format: Audio format from SUPPORTED_AUDIO_FORMATS.
+        sampleRate: Sample rate from SUPPORTED_SAMPLE_RATES.
+        channels: Channel count from SUPPORTED_CHANNELS.
+    """
+
+    audioData: bytes
+    format: Literal["pcm", "wav", "opus", "mp3"]
+    sampleRate: Literal[16000, 24000, 48000]
+    channels: Literal[1, 2]
+
+
+class ImageInputEvent(TypedDict):
+    """Image input event for sending images/video frames to the model.
+
+    Used for sending image data through the send() method. Supports both
+    raw image bytes and base64-encoded data.
+
+    Attributes:
+        imageData: Image bytes (raw or base64-encoded string).
+        mimeType: MIME type (e.g., "image/jpeg", "image/png").
+        encoding: How the imageData is encoded.
+    """
+
+    imageData: bytes | str
+    mimeType: str
+    encoding: Literal["base64", "raw"]
+
+
+class TextInputEvent(TypedDict):
+    """Text input event for sending text messages to the model.
+
+    Used for sending text messages through the send() method.
+
+    Attributes:
+        text: The text content to send to the model.
+    """
+
+    text: str
+
+
+class TextOutputEvent(TypedDict):
+    """Text output event from the model during bidirectional streaming.
+
+    Attributes:
+        text: The text content from the model.
+        role: The role of the message sender.
+    """
+
+    text: str
+    role: Role
+
+
+class TranscriptEvent(TypedDict):
+    """Transcript event for audio transcriptions.
+
+    Used for both input transcriptions (user speech) and output transcriptions
+    (model audio). These are informational and separate from actual text responses.
+
+    Attributes:
+        text: The transcribed text.
+        role: The role of the speaker ("user" or "assistant").
+        type: Type of transcription ("input" or "output").
+    """
+
+    text: str
+    role: Role
+    type: Literal["input", "output"]
+
+
+class InterruptionDetectedEvent(TypedDict):
+    """Interruption detection event.
+
+    Signals when user interruption is detected during model generation.
+
+    Attributes:
+        reason: Interruption reason from predefined set.
+    """
+
+    reason: Literal["user_input", "vad_detected", "manual"]
+
+
+class BidirectionalConnectionStartEvent(TypedDict, total=False):
+    """Connection start event for bidirectional streaming.
+
+    Attributes:
+        connectionId: Unique connection identifier.
+        metadata: Provider-specific connection metadata.
+    """
+
+    connectionId: Optional[str]
+    metadata: Optional[Dict[str, Any]]
+
+
+class BidirectionalConnectionEndEvent(TypedDict):
+    """Connection end event for bidirectional streaming.
+
+    Attributes:
+        reason: Reason for connection end from predefined set.
+        connectionId: Unique connection identifier.
+        metadata: Provider-specific connection metadata.
+    """
+
+    reason: Literal["user_request", "timeout", "error", "connection_complete"]
+    connectionId: Optional[str]
+    metadata: Optional[Dict[str, Any]]
+
+
+class UsageMetricsEvent(TypedDict):
+    """Token usage and performance tracking.
+
+    Provides standardized usage metrics across providers for cost monitoring
+    and performance optimization.
+
+    Attributes:
+        totalTokens: Total tokens used in the interaction.
+        inputTokens: Tokens used for input processing.
+        outputTokens: Tokens used for output generation.
+        audioTokens: Tokens used specifically for audio processing.
+    """
+
+    totalTokens: Optional[int]
+    inputTokens: Optional[int]
+    outputTokens: Optional[int]
+    audioTokens: Optional[int]
+
+
+class VoiceActivityEvent(TypedDict):
+    """Voice activity detection event for speech monitoring.
+
+    Provides standardized voice activity detection events across providers
+    to enable speech-aware applications and better conversation flow.
+
+    Attributes:
+        activityType: Type of voice activity detected.
+    """
+
+    activityType: Literal["speech_started", "speech_stopped", "timeout"]
+
+
+class BidirectionalStreamEvent(StreamEvent, total=False):
+    """Bidirectional stream event extending existing StreamEvent.
+
+    Extends the existing StreamEvent type with bidirectional-specific events
+    while maintaining full backward compatibility with existing Strands streaming.
+
+    Attributes:
+        audioOutput: Audio output from the model.
+        audioInput: Audio input sent to the model.
+        imageInput: Image input sent to the model.
+        textOutput: Text output from the model.
+        transcript: Audio transcription (input or output).
+        interruptionDetected: User interruption detection.
+        BidirectionalConnectionStart: Connection start event.
+        BidirectionalConnectionEnd: Connection end event.
+        voiceActivity: Voice activity detection events.
+        usageMetrics: Token usage and performance metrics.
+    """
+
+    audioOutput: Optional[AudioOutputEvent]
+    audioInput: Optional[AudioInputEvent]
+    imageInput: Optional[ImageInputEvent]
+    textOutput: Optional[TextOutputEvent]
+    transcript: Optional[TranscriptEvent]
+    interruptionDetected: Optional[InterruptionDetectedEvent]
+    BidirectionalConnectionStart: Optional[BidirectionalConnectionStartEvent]
+    BidirectionalConnectionEnd: Optional[BidirectionalConnectionEndEvent]
+    voiceActivity: Optional[VoiceActivityEvent]
+    usageMetrics: Optional[UsageMetricsEvent]
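+
+
+# Illustrative usage sketch (not part of the public API): consumers typically branch on
+# which key is present in a BidirectionalStreamEvent, as the test scripts in this change
+# do with BidirectionalAgent.receive(). Handler names such as clear_audio_queue() below
+# are hypothetical placeholders.
+#
+#     async for event in agent.receive():
+#         if "audioOutput" in event:
+#             speaker.write(event["audioOutput"]["audioData"])  # raw bytes, per AudioOutputEvent
+#         elif "textOutput" in event:
+#             print(event["textOutput"]["role"], event["textOutput"]["text"])
+#         elif "interruptionDetected" in event:
+#             clear_audio_queue()  # hypothetical: drop buffered audio on interruption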