Skip to content

Commit b4120d8

Browse files
committed
fix: add timeout support for MCP tools to resolve long-running tool hangs
- Add optional timeout parameter to MCPAgentTool constructor - Add default_tool_timeout parameter to MCPClient constructor - Update call_tool methods to use default timeout when none specified - Maintain full backward compatibility with existing code - Add comprehensive test coverage for timeout functionality Fixes strands-agents#625: MCP Client Tool Timeout Issue in Multi-Agent Orchestration 🤖 Assisted by the Amazon Q Developer
1 parent 1f25512 commit b4120d8

File tree

5 files changed

+258
-26
lines changed

5 files changed

+258
-26
lines changed

src/strands/tools/mcp/mcp_agent_tool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# ABOUTME: Adapter class that wraps MCP tools and exposes them as AgentTools with timeout support
2+
# ABOUTME: Bridges the gap between MCP protocol tools and the agent framework's tool interface
13
"""MCP Agent Tool module for adapting Model Context Protocol tools to the agent framework.
24
35
This module provides the MCPAgentTool class which serves as an adapter between
@@ -6,7 +8,8 @@
68
"""
79

810
import logging
9-
from typing import TYPE_CHECKING, Any
11+
from datetime import timedelta
12+
from typing import TYPE_CHECKING, Any, Optional
1013

1114
from mcp.types import Tool as MCPTool
1215
from typing_extensions import override
@@ -28,17 +31,19 @@ class MCPAgentTool(AgentTool):
2831
seamlessly within the agent framework.
2932
"""
3033

31-
def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient") -> None:
34+
def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", *, timeout: Optional[timedelta] = None) -> None:
3235
"""Initialize a new MCPAgentTool instance.
3336
3437
Args:
3538
mcp_tool: The MCP tool to adapt
3639
mcp_client: The MCP server connection to use for tool invocation
40+
timeout: Optional timeout for tool execution. If None, uses the MCP client's default timeout.
3741
"""
3842
super().__init__()
3943
logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name)
4044
self.mcp_tool = mcp_tool
4145
self.mcp_client = mcp_client
46+
self._timeout = timeout
4247

4348
@property
4449
def tool_name(self) -> str:
@@ -96,5 +101,6 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
96101
tool_use_id=tool_use["toolUseId"],
97102
name=self.tool_name,
98103
arguments=tool_use["input"],
104+
read_timeout_seconds=self._timeout,
99105
)
100106
yield ToolResultEvent(result)

src/strands/tools/mcp/mcp_client.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,24 @@ class MCPClient:
6363
from MCP tools, it will be returned as the last item in the content array of the ToolResult.
6464
"""
6565

66-
def __init__(self, transport_callable: Callable[[], MCPTransport], *, startup_timeout: int = 30):
66+
def __init__(
67+
self,
68+
transport_callable: Callable[[], MCPTransport],
69+
*,
70+
startup_timeout: int = 30,
71+
default_tool_timeout: Optional[timedelta] = None,
72+
):
6773
"""Initialize a new MCP Server connection.
6874
6975
Args:
7076
transport_callable: A callable that returns an MCPTransport (read_stream, write_stream) tuple
7177
startup_timeout: Timeout after which MCP server initialization should be cancelled
7278
Defaults to 30.
79+
default_tool_timeout: Default timeout for tool calls when no specific timeout is provided.
80+
If None, no default timeout is applied.
7381
"""
7482
self._startup_timeout = startup_timeout
83+
self._default_tool_timeout = default_tool_timeout
7584

7685
mcp_instrumentation()
7786
self._session_id = uuid.uuid4()
@@ -149,7 +158,7 @@ def stop(
149158
- _background_thread_event_loop: AsyncIO event loop in background thread
150159
- _close_event: AsyncIO event to signal thread shutdown
151160
- _init_future: Future for initialization synchronization
152-
161+
153162
Cleanup order:
154163
1. Signal close event to background thread (if session initialized)
155164
2. Wait for background thread to complete
@@ -275,7 +284,7 @@ def call_tool_sync(
275284
tool_use_id: Unique identifier for this tool use
276285
name: Name of the tool to call
277286
arguments: Optional arguments to pass to the tool
278-
read_timeout_seconds: Optional timeout for the tool call
287+
read_timeout_seconds: Optional timeout for the tool call. If None, uses the client's default timeout.
279288
280289
Returns:
281290
MCPToolResult: The result of the tool call
@@ -284,9 +293,12 @@ def call_tool_sync(
284293
if not self._is_session_active():
285294
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
286295

296+
# Use default timeout if no specific timeout provided
297+
effective_timeout = read_timeout_seconds if read_timeout_seconds is not None else self._default_tool_timeout
298+
287299
async def _call_tool_async() -> MCPCallToolResult:
288300
return await cast(ClientSession, self._background_thread_session).call_tool(
289-
name, arguments, read_timeout_seconds
301+
name, arguments, effective_timeout
290302
)
291303

292304
try:
@@ -312,7 +324,7 @@ async def call_tool_async(
312324
tool_use_id: Unique identifier for this tool use
313325
name: Name of the tool to call
314326
arguments: Optional arguments to pass to the tool
315-
read_timeout_seconds: Optional timeout for the tool call
327+
read_timeout_seconds: Optional timeout for the tool call. If None, uses the client's default timeout.
316328
317329
Returns:
318330
MCPToolResult: The result of the tool call
@@ -321,9 +333,12 @@ async def call_tool_async(
321333
if not self._is_session_active():
322334
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
323335

336+
# Use default timeout if no specific timeout provided
337+
effective_timeout = read_timeout_seconds if read_timeout_seconds is not None else self._default_tool_timeout
338+
324339
async def _call_tool_async() -> MCPCallToolResult:
325340
return await cast(ClientSession, self._background_thread_session).call_tool(
326-
name, arguments, read_timeout_seconds
341+
name, arguments, effective_timeout
327342
)
328343

329344
try:

tests/strands/tools/mcp/test_mcp_agent_tool.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from unittest.mock import MagicMock
1+
from datetime import timedelta
2+
from unittest.mock import AsyncMock, MagicMock
23

34
import pytest
45
from mcp.types import Tool as MCPTool
56

67
from strands.tools.mcp import MCPAgentTool, MCPClient
8+
from strands.tools.mcp.mcp_types import MCPToolResult
79
from strands.types._events import ToolResultEvent
810

911

@@ -24,6 +26,10 @@ def mock_mcp_client():
2426
"toolUseId": "test-123",
2527
"content": [{"text": "Success result"}],
2628
}
29+
mock_server.call_tool_async = AsyncMock()
30+
mock_server.call_tool_async.return_value = MCPToolResult(
31+
status="success", toolUseId="test-123", content=[{"text": "Success result"}]
32+
)
2733
return mock_server
2834

2935

@@ -58,6 +64,27 @@ def test_tool_spec_without_description(mock_mcp_tool, mock_mcp_client):
5864
assert tool_spec["description"] == "Tool which performs test_tool"
5965

6066

67+
def test_mcp_agent_tool_with_timeout_initialization(mock_mcp_tool, mock_mcp_client):
68+
"""Test that MCPAgentTool initializes correctly with timeout parameter."""
69+
timeout = timedelta(minutes=5)
70+
tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
71+
72+
assert tool.tool_name == "test_tool"
73+
assert tool.mcp_tool == mock_mcp_tool
74+
assert tool.mcp_client == mock_mcp_client
75+
assert tool._timeout == timeout
76+
77+
78+
def test_mcp_agent_tool_without_timeout_initialization(mock_mcp_tool, mock_mcp_client):
79+
"""Test that MCPAgentTool initializes correctly without timeout parameter."""
80+
tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client)
81+
82+
assert tool.tool_name == "test_tool"
83+
assert tool.mcp_tool == mock_mcp_tool
84+
assert tool.mcp_client == mock_mcp_client
85+
assert tool._timeout is None
86+
87+
6188
@pytest.mark.asyncio
6289
async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
6390
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}
@@ -67,5 +94,38 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
6794

6895
assert tru_events == exp_events
6996
mock_mcp_client.call_tool_async.assert_called_once_with(
70-
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
97+
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None
98+
)
99+
100+
101+
@pytest.mark.asyncio
102+
async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist):
103+
"""Test that stream method passes timeout to MCP client."""
104+
timeout = timedelta(minutes=5)
105+
tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
106+
107+
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}
108+
109+
tru_events = await alist(tool.stream(tool_use, {}))
110+
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]
111+
112+
assert tru_events == exp_events
113+
mock_mcp_client.call_tool_async.assert_called_once_with(
114+
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout
115+
)
116+
117+
118+
@pytest.mark.asyncio
119+
async def test_stream_without_timeout_passes_none(mock_mcp_tool, mock_mcp_client, alist):
120+
"""Test that stream method passes None timeout when no timeout configured."""
121+
tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) # No timeout
122+
123+
tool_use = {"toolUseId": "test-123", "name": "test_tool", "input": {"param": "value"}}
124+
125+
tru_events = await alist(tool.stream(tool_use, {}))
126+
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]
127+
128+
assert tru_events == exp_events
129+
mock_mcp_client.call_tool_async.assert_called_once_with(
130+
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None
71131
)

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 96 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -522,21 +522,101 @@ def test_stop_with_background_thread_but_no_event_loop():
522522
# Verify cleanup occurred
523523
assert client._background_thread is None
524524

525-
526-
def test_mcp_client_state_reset_after_timeout():
527-
"""Test that all client state is properly reset after timeout."""
528-
def slow_transport():
529-
time.sleep(4) # Longer than timeout
530-
return MagicMock()
531525

532-
client = MCPClient(slow_transport, startup_timeout=2)
533-
534-
# First attempt should timeout
535-
with pytest.raises(MCPClientInitializationError, match="background thread did not start in 2 seconds"):
536-
client.start()
526+
@pytest.mark.asyncio
527+
async def test_call_tool_async_with_default_timeout(mock_transport, mock_session):
528+
"""Test that call_tool_async uses default timeout when none specified."""
529+
from datetime import timedelta
537530

538-
# Verify all state is reset
539-
assert client._background_thread is None
540-
assert client._background_thread_session is None
541-
assert client._background_thread_event_loop is None
542-
assert not client._init_future.done() # New future created
531+
mock_content = MCPTextContent(type="text", text="Test message")
532+
mock_result = MCPCallToolResult(isError=False, content=[mock_content])
533+
mock_session.call_tool.return_value = mock_result
534+
535+
default_timeout = timedelta(minutes=10)
536+
with MCPClient(mock_transport["transport_callable"], default_tool_timeout=default_timeout) as client:
537+
with (
538+
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
539+
patch("asyncio.wrap_future") as mock_wrap_future,
540+
):
541+
mock_future = MagicMock()
542+
mock_run_coroutine_threadsafe.return_value = mock_future
543+
544+
# Create an async mock that resolves to the mock result
545+
async def mock_awaitable():
546+
return mock_result
547+
548+
mock_wrap_future.return_value = mock_awaitable()
549+
550+
result = await client.call_tool_async(
551+
tool_use_id="test-123",
552+
name="test_tool",
553+
arguments={"param": "value"},
554+
# No read_timeout_seconds specified - should use default
555+
)
556+
557+
# Verify the default timeout was used
558+
mock_run_coroutine_threadsafe.assert_called_once()
559+
mock_wrap_future.assert_called_once_with(mock_future)
560+
561+
assert result["status"] == "success"
562+
assert result["toolUseId"] == "test-123"
563+
564+
565+
def test_call_tool_sync_with_default_timeout(mock_transport, mock_session):
566+
"""Test that call_tool_sync uses default timeout when none specified."""
567+
from datetime import timedelta
568+
569+
mock_content = MCPTextContent(type="text", text="Test message")
570+
mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content])
571+
572+
default_timeout = timedelta(minutes=10)
573+
with MCPClient(mock_transport["transport_callable"], default_tool_timeout=default_timeout) as client:
574+
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
575+
576+
# The session.call_tool should have been called with the default timeout
577+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, default_timeout)
578+
579+
assert result["status"] == "success"
580+
assert result["toolUseId"] == "test-123"
581+
582+
583+
def test_call_tool_sync_explicit_timeout_overrides_default(mock_transport, mock_session):
584+
"""Test that explicit timeout overrides default timeout."""
585+
from datetime import timedelta
586+
587+
mock_content = MCPTextContent(type="text", text="Test message")
588+
mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content])
589+
590+
default_timeout = timedelta(minutes=10)
591+
explicit_timeout = timedelta(minutes=5)
592+
593+
with MCPClient(mock_transport["transport_callable"], default_tool_timeout=default_timeout) as client:
594+
result = client.call_tool_sync(
595+
tool_use_id="test-123",
596+
name="test_tool",
597+
arguments={"param": "value"},
598+
read_timeout_seconds=explicit_timeout,
599+
)
600+
601+
# The session.call_tool should have been called with the explicit timeout, not default
602+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, explicit_timeout)
603+
604+
assert result["status"] == "success"
605+
assert result["toolUseId"] == "test-123"
606+
607+
608+
def test_mcp_client_initialization_with_default_timeout():
609+
"""Test that MCPClient can be initialized with default_tool_timeout."""
610+
from datetime import timedelta
611+
612+
default_timeout = timedelta(minutes=15)
613+
client = MCPClient(MagicMock(), default_tool_timeout=default_timeout)
614+
615+
assert client._default_tool_timeout == default_timeout
616+
617+
618+
def test_mcp_client_initialization_without_default_timeout():
619+
"""Test that MCPClient initializes with None default_tool_timeout when not specified."""
620+
client = MCPClient(MagicMock())
621+
622+
assert client._default_tool_timeout is None

0 commit comments

Comments
 (0)