33import abc
44import asyncio
55from contextlib import AbstractAsyncContextManager , AsyncExitStack
6+ from datetime import timedelta
67from pathlib import Path
78from typing import Any , Literal
89
@@ -54,7 +55,7 @@ async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None) -> C
5455class _MCPServerWithClientSession (MCPServer , abc .ABC ):
5556 """Base class for MCP servers that use a `ClientSession` to communicate with the server."""
5657
57- def __init__ (self , cache_tools_list : bool ):
58+ def __init__ (self , cache_tools_list : bool , client_session_timeout_seconds : float | None ):
5859 """
5960 Args:
6061 cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
@@ -63,12 +64,16 @@ def __init__(self, cache_tools_list: bool):
6364 by calling `invalidate_tools_cache()`. You should set this to `True` if you know the
6465 server will not change its tools list, because it can drastically improve latency
6566 (by avoiding a round-trip to the server every time).
67+
68+ client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
6669 """
6770 self .session : ClientSession | None = None
6871 self .exit_stack : AsyncExitStack = AsyncExitStack ()
6972 self ._cleanup_lock : asyncio .Lock = asyncio .Lock ()
7073 self .cache_tools_list = cache_tools_list
7174
75+ self .client_session_timeout_seconds = client_session_timeout_seconds
76+
7277 # The cache is always dirty at startup, so that we fetch tools at least once
7378 self ._cache_dirty = True
7479 self ._tools_list : list [MCPTool ] | None = None
@@ -101,7 +106,15 @@ async def connect(self):
101106 try :
102107 transport = await self .exit_stack .enter_async_context (self .create_streams ())
103108 read , write = transport
104- session = await self .exit_stack .enter_async_context (ClientSession (read , write ))
109+ session = await self .exit_stack .enter_async_context (
110+ ClientSession (
111+ read ,
112+ write ,
113+ timedelta (seconds = self .client_session_timeout_seconds )
114+ if self .client_session_timeout_seconds
115+ else None ,
116+ )
117+ )
105118 await session .initialize ()
106119 self .session = session
107120 except Exception as e :
@@ -183,6 +196,7 @@ def __init__(
183196 params : MCPServerStdioParams ,
184197 cache_tools_list : bool = False ,
185198 name : str | None = None ,
199+ client_session_timeout_seconds : float | None = 5 ,
186200 ):
187201 """Create a new MCP server based on the stdio transport.
188202
@@ -199,8 +213,9 @@ def __init__(
199213 improve latency (by avoiding a round-trip to the server every time).
200214 name: A readable name for the server. If not provided, we'll create one from the
201215 command.
216+ client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
202217 """
203- super ().__init__ (cache_tools_list )
218+ super ().__init__ (cache_tools_list , client_session_timeout_seconds )
204219
205220 self .params = StdioServerParameters (
206221 command = params ["command" ],
@@ -257,6 +272,7 @@ def __init__(
257272 params : MCPServerSseParams ,
258273 cache_tools_list : bool = False ,
259274 name : str | None = None ,
275+ client_session_timeout_seconds : float | None = 5 ,
260276 ):
261277 """Create a new MCP server based on the HTTP with SSE transport.
262278
@@ -274,8 +290,10 @@ def __init__(
274290
275291 name: A readable name for the server. If not provided, we'll create one from the
276292 URL.
293+
294+ client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
277295 """
278- super ().__init__ (cache_tools_list )
296+ super ().__init__ (cache_tools_list , client_session_timeout_seconds )
279297
280298 self .params = params
281299 self ._name = name or f"sse: { self .params ['url' ]} "
0 commit comments