22
33import base64
44import functools
5+ import warnings
56from abc import ABC , abstractmethod
67from asyncio import Lock
78from collections .abc import AsyncIterator , Awaitable , Sequence
89from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
910from dataclasses import dataclass , field , replace
11+ from datetime import timedelta
1012from pathlib import Path
1113from typing import Any , Callable
1214
3739 ) from _import_error
3840
3941# after mcp imports so any import error maps to this file, not _mcp.py
40- from . import _mcp , exceptions , messages , models
42+ from . import _mcp , _utils , exceptions , messages , models
4143
4244__all__ = 'MCPServer' , 'MCPServerStdio' , 'MCPServerHTTP' , 'MCPServerSSE' , 'MCPServerStreamableHTTP'
4345
@@ -59,6 +61,7 @@ class MCPServer(AbstractToolset[Any], ABC):
5961 log_level : mcp_types .LoggingLevel | None = None
6062 log_handler : LoggingFnT | None = None
6163 timeout : float = 5
64+ read_timeout : float = 5 * 60
6265 process_tool_call : ProcessToolCallback | None = None
6366 allow_sampling : bool = True
6467 max_retries : int = 1
@@ -208,6 +211,7 @@ async def __aenter__(self) -> Self:
208211 write_stream = self ._write_stream ,
209212 sampling_callback = self ._sampling_callback if self .allow_sampling else None ,
210213 logging_callback = self .log_handler ,
214+ read_timeout_seconds = timedelta (seconds = self .read_timeout ),
211215 )
212216 self ._client = await self ._exit_stack .enter_async_context (client )
213217
@@ -401,7 +405,7 @@ def __repr__(self) -> str:
401405 return f'MCPServerStdio(command={ self .command !r} , args={ self .args !r} , tool_prefix={ self .tool_prefix !r} )'
402406
403407
404- @dataclass
408+ @dataclass ( init = False )
405409class _MCPServerHTTP (MCPServer ):
406410 url : str
407411 """The URL of the endpoint on the MCP server."""
@@ -438,10 +442,10 @@ class _MCPServerHTTP(MCPServer):
438442 ```
439443 """
440444
441- sse_read_timeout : float = 5 * 60
442- """Maximum time in seconds to wait for new SSE messages before timing out.
445+ read_timeout : float = 5 * 60
446+ """Maximum time in seconds to wait for new messages before timing out.
443447
444- This timeout applies to the long-lived SSE connection after it's established.
448+ This timeout applies to the long-lived connection after it's established.
445449 If no new messages are received within this time, the connection will be considered stale
446450 and may be closed. Defaults to 5 minutes (300 seconds).
447451 """
@@ -485,6 +489,51 @@ class _MCPServerHTTP(MCPServer):
485489 sampling_model : models .Model | None = None
486490 """The model to use for sampling."""
487491
492+ def __init__ (
493+ self ,
494+ * ,
495+ url : str ,
496+ headers : dict [str , str ] | None = None ,
497+ http_client : httpx .AsyncClient | None = None ,
498+ read_timeout : float | None = None ,
499+ tool_prefix : str | None = None ,
500+ log_level : mcp_types .LoggingLevel | None = None ,
501+ log_handler : LoggingFnT | None = None ,
502+ timeout : float = 5 ,
503+ process_tool_call : ProcessToolCallback | None = None ,
504+ allow_sampling : bool = True ,
505+ max_retries : int = 1 ,
506+ sampling_model : models .Model | None = None ,
507+ ** kwargs : Any ,
508+ ):
509+ # Handle deprecated sse_read_timeout parameter
510+ if 'sse_read_timeout' in kwargs :
511+ if read_timeout is not None :
512+ raise TypeError ("'read_timeout' and 'sse_read_timeout' cannot be set at the same time." )
513+
514+ warnings .warn (
515+ "'sse_read_timeout' is deprecated, use 'read_timeout' instead." , DeprecationWarning , stacklevel = 2
516+ )
517+ read_timeout = kwargs .pop ('sse_read_timeout' )
518+
519+ _utils .validate_empty_kwargs (kwargs )
520+
521+ if read_timeout is None :
522+ read_timeout = 5 * 60
523+
524+ self .url = url
525+ self .headers = headers
526+ self .http_client = http_client
527+ self .tool_prefix = tool_prefix
528+ self .log_level = log_level
529+ self .log_handler = log_handler
530+ self .timeout = timeout
531+ self .process_tool_call = process_tool_call
532+ self .allow_sampling = allow_sampling
533+ self .max_retries = max_retries
534+ self .sampling_model = sampling_model
535+ self .read_timeout = read_timeout
536+
488537 @property
489538 @abstractmethod
490539 def _transport_client (
@@ -522,7 +571,7 @@ async def client_streams(
522571 self ._transport_client ,
523572 url = self .url ,
524573 timeout = self .timeout ,
525- sse_read_timeout = self .sse_read_timeout ,
574+ sse_read_timeout = self .read_timeout ,
526575 )
527576
528577 if self .http_client is not None :
@@ -549,7 +598,7 @@ def __repr__(self) -> str: # pragma: no cover
549598 return f'{ self .__class__ .__name__ } (url={ self .url !r} , tool_prefix={ self .tool_prefix !r} )'
550599
551600
552- @dataclass
601+ @dataclass ( init = False )
553602class MCPServerSSE (_MCPServerHTTP ):
554603 """An MCP server that connects over streamable HTTP connections.
555604
0 commit comments