Skip to content

Commit 5cfc9ed

Browse files
authored
iterative tool handler process (#340)
1 parent 1421aad commit 5cfc9ed

File tree

8 files changed

+73
-32
lines changed

8 files changed

+73
-32
lines changed

src/strands/agent/agent.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def caller(
129129
}
130130

131131
# Execute the tool
132-
tool_result = self._agent.tool_handler.process(
132+
events = self._agent.tool_handler.process(
133133
tool=tool_use,
134134
model=self._agent.model,
135135
system_prompt=self._agent.system_prompt,
@@ -138,6 +138,12 @@ def caller(
138138
kwargs=kwargs,
139139
)
140140

141+
try:
142+
while True:
143+
next(events)
144+
except StopIteration as stop:
145+
tool_result = cast(ToolResult, stop.value)
146+
141147
if record_direct_tool_call is not None:
142148
should_record_direct_tool_call = record_direct_tool_call
143149
else:

src/strands/handlers/tool_handler.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ..tools.registry import ToolRegistry
77
from ..types.content import Messages
88
from ..types.models import Model
9-
from ..types.tools import ToolConfig, ToolHandler, ToolUse
9+
from ..types.tools import ToolConfig, ToolGenerator, ToolHandler, ToolUse
1010

1111
logger = logging.getLogger(__name__)
1212

@@ -35,7 +35,7 @@ def process(
3535
messages: Messages,
3636
tool_config: ToolConfig,
3737
kwargs: dict[str, Any],
38-
) -> Any:
38+
) -> ToolGenerator:
3939
"""Process a tool invocation.
4040
4141
Looks up the tool in the registry and invokes it with the provided parameters.
@@ -48,8 +48,11 @@ def process(
4848
tool_config: Configuration for the tool.
4949
kwargs: Additional keyword arguments passed to the tool.
5050
51+
Yields:
52+
Events of the tool invocation.
53+
5154
Returns:
52-
The result of the tool invocation, or an error response if the tool fails or is not found.
55+
The final tool result or an error response if the tool fails or is not found.
5356
"""
5457
logger.debug("tool=<%s> | invoking", tool)
5558
tool_use_id = tool["toolUseId"]
@@ -82,7 +85,9 @@ def process(
8285
}
8386
)
8487

85-
return tool_func.invoke(tool, **kwargs)
88+
result = tool_func.invoke(tool, **kwargs)
89+
yield {"result": result} # Placeholder until tool_func becomes a generator from which we can yield from
90+
return result
8691

8792
except Exception as e:
8893
logger.exception("tool_name=<%s> | failed to process tool", tool_name)

src/strands/tools/executor.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
from ..tools.tools import InvalidToolUseNameException, validate_tool_use
1414
from ..types.content import Message
1515
from ..types.event_loop import ParallelToolExecutorInterface
16-
from ..types.tools import ToolResult, ToolUse
16+
from ..types.tools import ToolGenerator, ToolResult, ToolUse
1717

1818
logger = logging.getLogger(__name__)
1919

2020

2121
def run_tools(
22-
handler: Callable[[ToolUse], ToolResult],
22+
handler: Callable[[ToolUse], Generator[dict[str, Any], None, ToolResult]],
2323
tool_uses: list[ToolUse],
2424
event_loop_metrics: EventLoopMetrics,
2525
invalid_tool_use_ids: list[str],
@@ -44,16 +44,15 @@ def run_tools(
4444
Events of the tool invocations. Tool results are appended to `tool_results`.
4545
"""
4646

47-
def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]:
47+
def handle(tool: ToolUse) -> ToolGenerator:
4848
tracer = get_tracer()
4949
tool_call_span = tracer.start_tool_call_span(tool, parent_span)
5050

5151
tool_name = tool["name"]
5252
tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name)
5353
tool_start_time = time.time()
5454

55-
result = handler(tool)
56-
yield {"result": result} # Placeholder until handler becomes a generator from which we can yield from
55+
result = yield from handler(tool)
5756

5857
tool_success = result.get("status") == "success"
5958
tool_duration = time.time() - tool_start_time
@@ -74,14 +73,14 @@ def work(
7473
) -> ToolResult:
7574
events = handle(tool)
7675

77-
while True:
78-
try:
76+
try:
77+
while True:
7978
event = next(events)
8079
worker_queue.put((worker_id, event))
8180
worker_event.wait()
8281

83-
except StopIteration as stop:
84-
return cast(ToolResult, stop.value)
82+
except StopIteration as stop:
83+
return cast(ToolResult, stop.value)
8584

8685
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
8786

src/strands/types/tools.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from abc import ABC, abstractmethod
9-
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
9+
from typing import TYPE_CHECKING, Any, Generator, Literal, Optional, Union
1010

1111
from typing_extensions import TypedDict
1212

@@ -90,7 +90,7 @@ class ToolResult(TypedDict):
9090
toolUseId: The unique identifier of the tool use request that produced this result.
9191
"""
9292

93-
content: List[ToolResultContent]
93+
content: list[ToolResultContent]
9494
status: ToolResultStatus
9595
toolUseId: str
9696

@@ -122,9 +122,9 @@ class ToolChoiceTool(TypedDict):
122122

123123

124124
ToolChoice = Union[
125-
Dict[Literal["auto"], ToolChoiceAuto],
126-
Dict[Literal["any"], ToolChoiceAny],
127-
Dict[Literal["tool"], ToolChoiceTool],
125+
dict[Literal["auto"], ToolChoiceAuto],
126+
dict[Literal["any"], ToolChoiceAny],
127+
dict[Literal["tool"], ToolChoiceTool],
128128
]
129129
"""
130130
Configuration for how the model should choose tools.
@@ -135,6 +135,10 @@ class ToolChoiceTool(TypedDict):
135135
"""
136136

137137

138+
ToolGenerator = Generator[dict[str, Any], None, ToolResult]
139+
"""Generator of tool events and a returned tool result."""
140+
141+
138142
class ToolConfig(TypedDict):
139143
"""Configuration for tools in a model request.
140144
@@ -143,7 +147,7 @@ class ToolConfig(TypedDict):
143147
toolChoice: Configuration for how the model should choose tools.
144148
"""
145149

146-
tools: List[Tool]
150+
tools: list[Tool]
147151
toolChoice: ToolChoice
148152

149153

@@ -250,7 +254,7 @@ def process(
250254
messages: "Messages",
251255
tool_config: ToolConfig,
252256
kwargs: dict[str, Any],
253-
) -> ToolResult:
257+
) -> ToolGenerator:
254258
"""Process a tool use request and execute the tool.
255259
256260
Args:
@@ -261,7 +265,10 @@ def process(
261265
tool_config: The tool configuration for the current session.
262266
kwargs: Additional context-specific arguments.
263267
268+
Yields:
269+
Events of the tool invocation.
270+
264271
Returns:
265-
The result of the tool execution.
272+
The final tool result.
266273
"""
267274
...

tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,22 @@ def boto3_profile_path(boto3_profile, tmp_path, monkeypatch):
6868
monkeypatch.setenv("AWS_SHARED_CREDENTIALS_FILE", str(path))
6969

7070
return path
71+
72+
73+
## Itertools
74+
75+
76+
@pytest.fixture(scope="session")
77+
def generate():
78+
def generate(generator):
79+
events = []
80+
81+
try:
82+
while True:
83+
event = next(generator)
84+
events.append(event)
85+
86+
except StopIteration as stop:
87+
return events, stop.value
88+
89+
return generate

tests/strands/agent/test_agent.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ def test_agent_init_with_no_model_or_model_id():
853853

854854

855855
def test_agent_tool_no_parameter_conflict(agent, tool_registry, mock_randint):
856-
agent.tool_handler = unittest.mock.Mock()
856+
agent.tool_handler = unittest.mock.Mock(process=unittest.mock.Mock(return_value=iter([])))
857857

858858
@strands.tools.tool(name="system_prompter")
859859
def function(system_prompt: str) -> str:
@@ -880,7 +880,7 @@ def function(system_prompt: str) -> str:
880880

881881

882882
def test_agent_tool_with_name_normalization(agent, tool_registry, mock_randint):
883-
agent.tool_handler = unittest.mock.Mock()
883+
agent.tool_handler = unittest.mock.Mock(process=unittest.mock.Mock(return_value=iter([])))
884884

885885
tool_name = "system-prompter"
886886

@@ -908,8 +908,6 @@ def function(system_prompt: str) -> str:
908908

909909

910910
def test_agent_tool_with_no_normalized_match(agent, tool_registry, mock_randint):
911-
agent.tool_handler = unittest.mock.Mock()
912-
913911
mock_randint.return_value = 1
914912

915913
with pytest.raises(AttributeError) as err:

tests/strands/handlers/test_tool_handler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,33 @@ def identity(a: int) -> int:
2626
return {"toolUseId": "identity", "name": "identity", "input": {"a": 1}}
2727

2828

29-
def test_process(tool_handler, tool_use_identity):
30-
tru_result = tool_handler.process(
29+
def test_process(tool_handler, tool_use_identity, generate):
30+
process = tool_handler.process(
3131
tool_use_identity,
3232
model=unittest.mock.Mock(),
3333
system_prompt="p1",
3434
messages=[],
3535
tool_config={},
3636
kwargs={},
3737
)
38+
39+
_, tru_result = generate(process)
3840
exp_result = {"toolUseId": "identity", "status": "success", "content": [{"text": "1"}]}
3941

4042
assert tru_result == exp_result
4143

4244

43-
def test_process_missing_tool(tool_handler):
44-
tru_result = tool_handler.process(
45+
def test_process_missing_tool(tool_handler, generate):
46+
process = tool_handler.process(
4547
tool={"toolUseId": "missing", "name": "missing", "input": {}},
4648
model=unittest.mock.Mock(),
4749
system_prompt="p1",
4850
messages=[],
4951
tool_config={},
5052
kwargs={},
5153
)
54+
55+
_, tru_result = generate(process)
5256
exp_result = {
5357
"toolUseId": "missing",
5458
"status": "error",

tests/strands/tools/test_executor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def moto_autouse(moto_env):
1818
@pytest.fixture
1919
def tool_handler(request):
2020
def handler(tool_use):
21+
yield {"event": "abc"}
2122
return {
2223
**params,
2324
"toolUseId": tool_use["toolUseId"],
@@ -102,7 +103,9 @@ def test_run_tools(
102103
cycle_trace,
103104
parallel_tool_executor,
104105
)
105-
list(stream)
106+
107+
tru_events = list(stream)
108+
exp_events = [{"event": "abc"}]
106109

107110
tru_results = tool_results
108111
exp_results = [
@@ -117,7 +120,7 @@ def test_run_tools(
117120
},
118121
]
119122

120-
assert tru_results == exp_results
123+
assert tru_events == exp_events and tru_results == exp_results
121124

122125

123126
@pytest.mark.parametrize("invalid_tool_use_ids", [["t1"]], indirect=True)

0 commit comments

Comments
 (0)