Skip to content

Commit d4b02a6

Browse files
jer96jer
authored andcommitted
a2a streaming (#366)
Co-authored-by: jer <jerebill@amazon.com>
1 parent 219b227 commit d4b02a6

File tree

5 files changed

+311
-71
lines changed

5 files changed

+311
-71
lines changed

src/strands/multiagent/a2a/executor.py

Lines changed: 100 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,40 @@
22
33
This module provides the StrandsA2AExecutor class, which adapts a Strands Agent
44
to be used as an executor in the A2A protocol. It handles the execution of agent
5-
requests and the conversion of Strands Agent responses to A2A events.
5+
requests and the conversion of Strands Agent streamed responses to A2A events.
6+
7+
The A2A AgentExecutor ensures clients recieve responses for synchronous and
8+
streamed requests to the A2AServer.
69
"""
710

811
import logging
12+
from typing import Any
913

1014
from a2a.server.agent_execution import AgentExecutor, RequestContext
1115
from a2a.server.events import EventQueue
12-
from a2a.types import UnsupportedOperationError
13-
from a2a.utils import new_agent_text_message
16+
from a2a.server.tasks import TaskUpdater
17+
from a2a.types import InternalError, Part, TaskState, TextPart, UnsupportedOperationError
18+
from a2a.utils import new_agent_text_message, new_task
1419
from a2a.utils.errors import ServerError
1520

1621
from ...agent.agent import Agent as SAAgent
17-
from ...agent.agent_result import AgentResult as SAAgentResult
22+
from ...agent.agent import AgentResult as SAAgentResult
1823

19-
log = logging.getLogger(__name__)
24+
logger = logging.getLogger(__name__)
2025

2126

2227
class StrandsA2AExecutor(AgentExecutor):
23-
"""Executor that adapts a Strands Agent to the A2A protocol."""
28+
"""Executor that adapts a Strands Agent to the A2A protocol.
29+
30+
This executor uses streaming mode to handle the execution of agent requests
31+
and converts Strands Agent responses to A2A protocol events.
32+
"""
2433

2534
def __init__(self, agent: SAAgent):
2635
"""Initialize a StrandsA2AExecutor.
2736
2837
Args:
29-
agent: The Strands Agent to adapt to the A2A protocol.
38+
agent: The Strands Agent instance to adapt to the A2A protocol.
3039
"""
3140
self.agent = agent
3241

@@ -37,24 +46,97 @@ async def execute(
3746
) -> None:
3847
"""Execute a request using the Strands Agent and send the response as A2A events.
3948
40-
This method executes the user's input using the Strands Agent and converts
41-
the agent's response to A2A events, which are then sent to the event queue.
49+
This method executes the user's input using the Strands Agent in streaming mode
50+
and converts the agent's response to A2A events.
51+
52+
Args:
53+
context: The A2A request context, containing the user's input and task metadata.
54+
event_queue: The A2A event queue used to send response events back to the client.
55+
56+
Raises:
57+
ServerError: If an error occurs during agent execution
58+
"""
59+
task = context.current_task
60+
if not task:
61+
task = new_task(context.message) # type: ignore
62+
await event_queue.enqueue_event(task)
63+
64+
updater = TaskUpdater(event_queue, task.id, task.contextId)
65+
66+
try:
67+
await self._execute_streaming(context, updater)
68+
except Exception as e:
69+
raise ServerError(error=InternalError()) from e
70+
71+
async def _execute_streaming(self, context: RequestContext, updater: TaskUpdater) -> None:
72+
"""Execute request in streaming mode.
73+
74+
Streams the agent's response in real-time, sending incremental updates
75+
as they become available from the agent.
4276
4377
Args:
4478
context: The A2A request context, containing the user's input and other metadata.
45-
event_queue: The A2A event queue, used to send response events.
79+
updater: The task updater for managing task state and sending updates.
80+
"""
81+
logger.info("Executing request in streaming mode")
82+
user_input = context.get_user_input()
83+
try:
84+
async for event in self.agent.stream_async(user_input):
85+
await self._handle_streaming_event(event, updater)
86+
except Exception:
87+
logger.exception("Error in streaming execution")
88+
raise
89+
90+
async def _handle_streaming_event(self, event: dict[str, Any], updater: TaskUpdater) -> None:
91+
"""Handle a single streaming event from the Strands Agent.
92+
93+
Processes streaming events from the agent, converting data chunks to A2A
94+
task updates and handling the final result when streaming is complete.
95+
96+
Args:
97+
event: The streaming event from the agent, containing either 'data' for
98+
incremental content or 'result' for the final response.
99+
updater: The task updater for managing task state and sending updates.
100+
"""
101+
logger.debug("Streaming event: %s", event)
102+
if "data" in event:
103+
if text_content := event["data"]:
104+
await updater.update_status(
105+
TaskState.working,
106+
new_agent_text_message(
107+
text_content,
108+
updater.context_id,
109+
updater.task_id,
110+
),
111+
)
112+
elif "result" in event:
113+
await self._handle_agent_result(event["result"], updater)
114+
else:
115+
logger.warning("Unexpected streaming event: %s", event)
116+
117+
async def _handle_agent_result(self, result: SAAgentResult | None, updater: TaskUpdater) -> None:
118+
"""Handle the final result from the Strands Agent.
119+
120+
Processes the agent's final result, extracts text content from the response,
121+
and adds it as an artifact to the task before marking the task as complete.
122+
123+
Args:
124+
result: The agent result object containing the final response, or None if no result.
125+
updater: The task updater for managing task state and adding the final artifact.
46126
"""
47-
result: SAAgentResult = self.agent(context.get_user_input())
48-
if result.message and "content" in result.message:
49-
for content_block in result.message["content"]:
50-
if "text" in content_block:
51-
await event_queue.enqueue_event(new_agent_text_message(content_block["text"]))
127+
if final_content := str(result):
128+
await updater.add_artifact(
129+
[Part(root=TextPart(text=final_content))],
130+
name="agent_response",
131+
)
132+
await updater.complete()
52133

53134
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
54135
"""Cancel an ongoing execution.
55136
56-
This method is called when a request is cancelled. Currently, cancellation
57-
is not supported, so this method raises an UnsupportedOperationError.
137+
This method is called when a request cancellation is requested. Currently,
138+
cancellation is not supported by the Strands Agent executor, so this method
139+
always raises an UnsupportedOperationError.
58140
59141
Args:
60142
context: The A2A request context.
@@ -64,4 +146,5 @@ async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None
64146
ServerError: Always raised with an UnsupportedOperationError, as cancellation
65147
is not currently supported.
66148
"""
149+
logger.warning("Cancellation requested but not supported")
67150
raise ServerError(error=UnsupportedOperationError())

src/strands/multiagent/a2a/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ def __init__(
5252
self.strands_agent = agent
5353
self.name = self.strands_agent.name
5454
self.description = self.strands_agent.description
55-
# TODO: enable configurable capabilities and request handler
56-
self.capabilities = AgentCapabilities()
55+
self.capabilities = AgentCapabilities(streaming=True)
5756
self.request_handler = DefaultRequestHandler(
5857
agent_executor=StrandsA2AExecutor(self.strands_agent),
5958
task_store=InMemoryTaskStore(),

tests/multiagent/a2a/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ def mock_strands_agent():
2222
mock_result.message = {"content": [{"text": "Test response"}]}
2323
agent.return_value = mock_result
2424

25+
# Setup async methods
26+
agent.invoke_async = AsyncMock(return_value=mock_result)
27+
agent.stream_async = AsyncMock(return_value=iter([]))
28+
2529
# Setup mock tool registry
2630
mock_tool_registry = MagicMock()
2731
mock_tool_registry.get_all_tools_config.return_value = {}

0 commit comments

Comments
 (0)