Skip to content

Commit 52677ab

Browse files
authored
remove thread pool wrapper (#339)
1 parent 5cfc9ed commit 52677ab

File tree

11 files changed

+88
-296
lines changed

11 files changed

+88
-296
lines changed

src/strands/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@
33
from . import agent, event_loop, models, telemetry, types
44
from .agent.agent import Agent
55
from .tools.decorator import tool
6-
from .tools.thread_pool_executor import ThreadPoolExecutorWrapper
76

8-
__all__ = ["Agent", "ThreadPoolExecutorWrapper", "agent", "event_loop", "models", "tool", "types", "telemetry"]
7+
__all__ = ["Agent", "agent", "event_loop", "models", "tool", "types", "telemetry"]

src/strands/agent/agent.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from ..telemetry.metrics import EventLoopMetrics
2828
from ..telemetry.tracer import get_tracer
2929
from ..tools.registry import ToolRegistry
30-
from ..tools.thread_pool_executor import ThreadPoolExecutorWrapper
3130
from ..tools.watcher import ToolWatcher
3231
from ..types.content import ContentBlock, Message, Messages
3332
from ..types.exceptions import ContextWindowOverflowException
@@ -275,7 +274,6 @@ def __init__(
275274
self.thread_pool_wrapper = None
276275
if max_parallel_tools > 1:
277276
self.thread_pool = ThreadPoolExecutor(max_workers=max_parallel_tools)
278-
self.thread_pool_wrapper = ThreadPoolExecutorWrapper(self.thread_pool)
279277
elif max_parallel_tools < 1:
280278
raise ValueError("max_parallel_tools must be greater than 0")
281279

@@ -358,8 +356,8 @@ def __del__(self) -> None:
358356
359357
Ensures proper shutdown of the thread pool executor if one exists.
360358
"""
361-
if self.thread_pool_wrapper and hasattr(self.thread_pool_wrapper, "shutdown"):
362-
self.thread_pool_wrapper.shutdown(wait=False)
359+
if self.thread_pool:
360+
self.thread_pool.shutdown(wait=False)
363361
logger.debug("thread pool executor shutdown complete")
364362

365363
def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
@@ -528,7 +526,7 @@ def _execute_event_loop_cycle(self, kwargs: dict[str, Any]) -> Generator[dict[st
528526
messages=self.messages, # will be modified by event_loop_cycle
529527
tool_config=self.tool_config,
530528
tool_handler=self.tool_handler,
531-
tool_execution_handler=self.thread_pool_wrapper,
529+
thread_pool=self.thread_pool,
532530
event_loop_metrics=self.event_loop_metrics,
533531
event_loop_parent_span=self.trace_span,
534532
kwargs=kwargs,

src/strands/event_loop/event_loop.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import logging
1212
import time
1313
import uuid
14+
from concurrent.futures import ThreadPoolExecutor
1415
from functools import partial
1516
from typing import Any, Generator, Optional
1617

@@ -20,7 +21,6 @@
2021
from ..telemetry.tracer import get_tracer
2122
from ..tools.executor import run_tools, validate_and_prepare_tools
2223
from ..types.content import Message, Messages
23-
from ..types.event_loop import ParallelToolExecutorInterface
2424
from ..types.exceptions import ContextWindowOverflowException, EventLoopException, ModelThrottledException
2525
from ..types.models import Model
2626
from ..types.streaming import Metrics, StopReason
@@ -41,7 +41,7 @@ def event_loop_cycle(
4141
messages: Messages,
4242
tool_config: Optional[ToolConfig],
4343
tool_handler: Optional[ToolHandler],
44-
tool_execution_handler: Optional[ParallelToolExecutorInterface],
44+
thread_pool: Optional[ThreadPoolExecutor],
4545
event_loop_metrics: EventLoopMetrics,
4646
event_loop_parent_span: Optional[trace.Span],
4747
kwargs: dict[str, Any],
@@ -65,7 +65,7 @@ def event_loop_cycle(
6565
messages: Conversation history messages.
6666
tool_config: Configuration for available tools.
6767
tool_handler: Handler for executing tools.
68-
tool_execution_handler: Optional handler for parallel tool execution.
68+
thread_pool: Optional thread pool for parallel tool execution.
6969
event_loop_metrics: Metrics tracking object for the event loop.
7070
event_loop_parent_span: Span for the parent of this event loop.
7171
kwargs: Additional arguments including:
@@ -210,7 +210,7 @@ def event_loop_cycle(
210210
messages,
211211
tool_config,
212212
tool_handler,
213-
tool_execution_handler,
213+
thread_pool,
214214
event_loop_metrics,
215215
event_loop_parent_span,
216216
cycle_trace,
@@ -256,7 +256,7 @@ def recurse_event_loop(
256256
messages: Messages,
257257
tool_config: Optional[ToolConfig],
258258
tool_handler: Optional[ToolHandler],
259-
tool_execution_handler: Optional[ParallelToolExecutorInterface],
259+
thread_pool: Optional[ThreadPoolExecutor],
260260
event_loop_metrics: EventLoopMetrics,
261261
event_loop_parent_span: Optional[trace.Span],
262262
kwargs: dict[str, Any],
@@ -271,7 +271,7 @@ def recurse_event_loop(
271271
messages: Conversation history messages
272272
tool_config: Configuration for available tools
273273
tool_handler: Handler for tool execution
274-
tool_execution_handler: Optional handler for parallel tool execution.
274+
thread_pool: Optional thread pool for parallel tool execution.
275275
event_loop_metrics: Metrics tracking object for the event loop.
276276
event_loop_parent_span: Span for the parent of this event loop.
277277
kwargs: Arguments to pass through event_loop_cycle
@@ -298,7 +298,7 @@ def recurse_event_loop(
298298
messages=messages,
299299
tool_config=tool_config,
300300
tool_handler=tool_handler,
301-
tool_execution_handler=tool_execution_handler,
301+
thread_pool=thread_pool,
302302
event_loop_metrics=event_loop_metrics,
303303
event_loop_parent_span=event_loop_parent_span,
304304
kwargs=kwargs,
@@ -315,7 +315,7 @@ def _handle_tool_execution(
315315
messages: Messages,
316316
tool_config: ToolConfig,
317317
tool_handler: ToolHandler,
318-
tool_execution_handler: Optional[ParallelToolExecutorInterface],
318+
thread_pool: Optional[ThreadPoolExecutor],
319319
event_loop_metrics: EventLoopMetrics,
320320
event_loop_parent_span: Optional[trace.Span],
321321
cycle_trace: Trace,
@@ -331,20 +331,20 @@ def _handle_tool_execution(
331331
Handles the execution of tools requested by the model during an event loop cycle.
332332
333333
Args:
334-
stop_reason (StopReason): The reason the model stopped generating.
335-
message (Message): The message from the model that may contain tool use requests.
336-
model (Model): The model provider instance.
337-
system_prompt (Optional[str]): The system prompt instructions for the model.
338-
messages (Messages): The conversation history messages.
339-
tool_config (ToolConfig): Configuration for available tools.
340-
tool_handler (ToolHandler): Handler for tool execution.
341-
tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
342-
event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
343-
event_loop_parent_span (Any): Span for the parent of this event loop.
344-
cycle_trace (Trace): Trace object for the current event loop cycle.
345-
cycle_span (Any): Span object for tracing the cycle (type may vary).
346-
cycle_start_time (float): Start time of the current cycle.
347-
kwargs (dict[str, Any]): Additional keyword arguments, including request state.
334+
stop_reason: The reason the model stopped generating.
335+
message: The message from the model that may contain tool use requests.
336+
model: The model provider instance.
337+
system_prompt: The system prompt instructions for the model.
338+
messages: The conversation history messages.
339+
tool_config: Configuration for available tools.
340+
tool_handler: Handler for tool execution.
341+
thread_pool: Optional thread pool for parallel tool execution.
342+
event_loop_metrics: Metrics tracking object for the event loop.
343+
event_loop_parent_span: Span for the parent of this event loop.
344+
cycle_trace: Trace object for the current event loop cycle.
345+
cycle_span: Span object for tracing the cycle (type may vary).
346+
cycle_start_time: Start time of the current cycle.
347+
kwargs: Additional keyword arguments, including request state.
348348
349349
Yields:
350350
Tool invocation events along with events yielded from a recursive call to the event loop. The last event is a
@@ -377,7 +377,7 @@ def _handle_tool_execution(
377377
tool_results=tool_results,
378378
cycle_trace=cycle_trace,
379379
parent_span=cycle_span,
380-
parallel_tool_executor=tool_execution_handler,
380+
thread_pool=thread_pool,
381381
)
382382

383383
# Store parent cycle ID for the next cycle
@@ -406,7 +406,7 @@ def _handle_tool_execution(
406406
messages=messages,
407407
tool_config=tool_config,
408408
tool_handler=tool_handler,
409-
tool_execution_handler=tool_execution_handler,
409+
thread_pool=thread_pool,
410410
event_loop_metrics=event_loop_metrics,
411411
event_loop_parent_span=event_loop_parent_span,
412412
kwargs=kwargs,

src/strands/tools/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from .decorator import tool
77
from .structured_output import convert_pydantic_to_tool_spec
8-
from .thread_pool_executor import ThreadPoolExecutorWrapper
98
from .tools import InvalidToolUseNameException, PythonAgentTool, normalize_schema, normalize_tool_spec
109

1110
__all__ = [
@@ -14,6 +13,5 @@
1413
"InvalidToolUseNameException",
1514
"normalize_schema",
1615
"normalize_tool_spec",
17-
"ThreadPoolExecutorWrapper",
1816
"convert_pydantic_to_tool_spec",
1917
]

src/strands/tools/executor.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import queue
55
import threading
66
import time
7+
from concurrent.futures import ThreadPoolExecutor
78
from typing import Any, Callable, Generator, Optional, cast
89

910
from opentelemetry import trace
@@ -12,7 +13,6 @@
1213
from ..telemetry.tracer import get_tracer
1314
from ..tools.tools import InvalidToolUseNameException, validate_tool_use
1415
from ..types.content import Message
15-
from ..types.event_loop import ParallelToolExecutorInterface
1616
from ..types.tools import ToolGenerator, ToolResult, ToolUse
1717

1818
logger = logging.getLogger(__name__)
@@ -26,7 +26,7 @@ def run_tools(
2626
tool_results: list[ToolResult],
2727
cycle_trace: Trace,
2828
parent_span: Optional[trace.Span] = None,
29-
parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None,
29+
thread_pool: Optional[ThreadPoolExecutor] = None,
3030
) -> Generator[dict[str, Any], None, None]:
3131
"""Execute tools either in parallel or sequentially.
3232
@@ -38,7 +38,7 @@ def run_tools(
3838
tool_results: List to populate with tool results.
3939
cycle_trace: Parent trace for the current cycle.
4040
parent_span: Parent span for the current cycle.
41-
parallel_tool_executor: Optional executor for parallel processing.
41+
thread_pool: Optional thread pool for parallel processing.
4242
4343
Yields:
4444
Events of the tool invocations. Tool results are appended to `tool_results`.
@@ -84,18 +84,14 @@ def work(
8484

8585
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
8686

87-
if parallel_tool_executor:
88-
logger.debug(
89-
"tool_count=<%s>, tool_executor=<%s> | executing tools in parallel",
90-
len(tool_uses),
91-
type(parallel_tool_executor).__name__,
92-
)
87+
if thread_pool:
88+
logger.debug("tool_count=<%s> | executing tools in parallel", len(tool_uses))
9389

9490
worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue()
9591
worker_events = [threading.Event() for _ in range(len(tool_uses))]
9692

9793
workers = [
98-
parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
94+
thread_pool.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id])
9995
for worker_id, tool_use in enumerate(tool_uses)
10096
]
10197
logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses))

src/strands/tools/thread_pool_executor.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

src/strands/types/event_loop.py

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Event loop-related type definitions for the SDK."""
22

3-
from typing import Any, Callable, Iterable, Iterator, Literal, Optional, Protocol
3+
from typing import Literal
44

5-
from typing_extensions import TypedDict, runtime_checkable
5+
from typing_extensions import TypedDict
66

77

88
class Usage(TypedDict):
@@ -46,69 +46,3 @@ class Metrics(TypedDict):
4646
- "stop_sequence": Stop sequence encountered
4747
- "tool_use": Model requested to use a tool
4848
"""
49-
50-
51-
@runtime_checkable
52-
class Future(Protocol):
53-
"""Interface representing the result of an asynchronous computation."""
54-
55-
def result(self, timeout: Optional[int] = None) -> Any:
56-
"""Return the result of the call that the future represents.
57-
58-
This method will block until the asynchronous operation completes or until the specified timeout is reached.
59-
60-
Args:
61-
timeout: The number of seconds to wait for the result.
62-
If None, then there is no limit on the wait time.
63-
64-
Returns:
65-
Any: The result of the asynchronous operation.
66-
"""
67-
68-
def done(self) -> bool:
69-
"""Returns true if future is done executing."""
70-
71-
72-
@runtime_checkable
73-
class ParallelToolExecutorInterface(Protocol):
74-
"""Interface for parallel tool execution.
75-
76-
Attributes:
77-
timeout: Default timeout in seconds for futures.
78-
"""
79-
80-
timeout: int = 900 # default 15 minute timeout for futures
81-
82-
def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
83-
"""Submit a callable to be executed with the given arguments.
84-
85-
Schedules the callable to be executed as fn(*args, **kwargs) and returns a Future instance representing the
86-
execution of the callable.
87-
88-
Args:
89-
fn: The callable to execute.
90-
*args: Positional arguments to pass to the callable.
91-
**kwargs: Keyword arguments to pass to the callable.
92-
93-
Returns:
94-
Future: A Future representing the given call.
95-
"""
96-
97-
def as_completed(self, futures: Iterable[Future], timeout: Optional[int] = timeout) -> Iterator[Future]:
98-
"""Iterate over the given futures, yielding each as it completes.
99-
100-
Args:
101-
futures: The sequence of Futures to iterate over.
102-
timeout: The maximum number of seconds to wait.
103-
If None, then there is no limit on the wait time.
104-
105-
Returns:
106-
An iterator that yields the given Futures as they complete (finished or cancelled).
107-
"""
108-
109-
def shutdown(self, wait: bool = True) -> None:
110-
"""Shutdown the executor and free associated resources.
111-
112-
Args:
113-
wait: If True, shutdown will not return until all running futures have finished executing.
114-
"""

0 commit comments

Comments
 (0)