1010from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
1111from mcp import ClientSession , StdioServerParameters , Tool as MCPTool , stdio_client
1212from mcp .client .sse import sse_client
13- from mcp .types import CallToolResult , JSONRPCMessage
13+ from mcp .client .streamable_http import GetSessionIdCallback , streamablehttp_client
14+ from mcp .shared .message import SessionMessage
15+ from mcp .types import CallToolResult
1416from typing_extensions import NotRequired , TypedDict
1517
1618from ..exceptions import UserError
@@ -83,8 +85,9 @@ def create_streams(
8385 self ,
8486 ) -> AbstractAsyncContextManager [
8587 tuple [
86- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
87- MemoryObjectSendStream [JSONRPCMessage ],
88+ MemoryObjectReceiveStream [SessionMessage | Exception ],
89+ MemoryObjectSendStream [SessionMessage ],
90+ GetSessionIdCallback | None
8891 ]
8992 ]:
9093 """Create the streams for the server."""
@@ -105,7 +108,11 @@ async def connect(self):
105108 """Connect to the server."""
106109 try :
107110 transport = await self .exit_stack .enter_async_context (self .create_streams ())
108- read , write = transport
111+ # streamablehttp_client returns (read, write, get_session_id)
112+ # sse_client returns (read, write)
113+
114+ read , write , * _ = transport
115+
109116 session = await self .exit_stack .enter_async_context (
110117 ClientSession (
111118 read ,
@@ -232,8 +239,9 @@ def create_streams(
232239 self ,
233240 ) -> AbstractAsyncContextManager [
234241 tuple [
235- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
236- MemoryObjectSendStream [JSONRPCMessage ],
242+ MemoryObjectReceiveStream [SessionMessage | Exception ],
243+ MemoryObjectSendStream [SessionMessage ],
244+ GetSessionIdCallback | None
237245 ]
238246 ]:
239247 """Create the streams for the server."""
@@ -302,8 +310,9 @@ def create_streams(
302310 self ,
303311 ) -> AbstractAsyncContextManager [
304312 tuple [
305- MemoryObjectReceiveStream [JSONRPCMessage | Exception ],
306- MemoryObjectSendStream [JSONRPCMessage ],
313+ MemoryObjectReceiveStream [SessionMessage | Exception ],
314+ MemoryObjectSendStream [SessionMessage ],
315+ GetSessionIdCallback | None
307316 ]
308317 ]:
309318 """Create the streams for the server."""
@@ -318,3 +327,84 @@ def create_streams(
318327 def name (self ) -> str :
319328 """A readable name for the server."""
320329 return self ._name
330+
331+
332+ class MCPServerStreamableHttpParams (TypedDict ):
333+ """Mirrors the params in`mcp.client.streamable_http.streamablehttp_client`."""
334+
335+ url : str
336+ """The URL of the server."""
337+
338+ headers : NotRequired [dict [str , str ]]
339+ """The headers to send to the server."""
340+
341+ timeout : NotRequired [timedelta ]
342+ """The timeout for the HTTP request. Defaults to 5 seconds."""
343+
344+ sse_read_timeout : NotRequired [timedelta ]
345+ """The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""
346+
347+ terminate_on_close : NotRequired [bool ]
348+ """Terminate on close"""
349+
350+
351+ class MCPServerStreamableHttp (_MCPServerWithClientSession ):
352+ """MCP server implementation that uses the Streamable HTTP transport. See the [spec]
353+ (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http)
354+ for details.
355+ """
356+
357+ def __init__ (
358+ self ,
359+ params : MCPServerStreamableHttpParams ,
360+ cache_tools_list : bool = False ,
361+ name : str | None = None ,
362+ client_session_timeout_seconds : float | None = 5 ,
363+ ):
364+ """Create a new MCP server based on the Streamable HTTP transport.
365+
366+ Args:
367+ params: The params that configure the server. This includes the URL of the server,
368+ the headers to send to the server, the timeout for the HTTP request, and the
369+ timeout for the Streamable HTTP connection and whether we need to
370+ terminate on close.
371+
372+ cache_tools_list: Whether to cache the tools list. If `True`, the tools list will be
373+ cached and only fetched from the server once. If `False`, the tools list will be
374+ fetched from the server on each call to `list_tools()`. The cache can be
375+ invalidated by calling `invalidate_tools_cache()`. You should set this to `True`
376+ if you know the server will not change its tools list, because it can drastically
377+ improve latency (by avoiding a round-trip to the server every time).
378+
379+ name: A readable name for the server. If not provided, we'll create one from the
380+ URL.
381+
382+ client_session_timeout_seconds: the read timeout passed to the MCP ClientSession.
383+ """
384+ super ().__init__ (cache_tools_list , client_session_timeout_seconds )
385+
386+ self .params = params
387+ self ._name = name or f"streamable_http: { self .params ['url' ]} "
388+
389+ def create_streams (
390+ self ,
391+ ) -> AbstractAsyncContextManager [
392+ tuple [
393+ MemoryObjectReceiveStream [SessionMessage | Exception ],
394+ MemoryObjectSendStream [SessionMessage ],
395+ GetSessionIdCallback | None
396+ ]
397+ ]:
398+ """Create the streams for the server."""
399+ return streamablehttp_client (
400+ url = self .params ["url" ],
401+ headers = self .params .get ("headers" , None ),
402+ timeout = self .params .get ("timeout" , timedelta (seconds = 30 )),
403+ sse_read_timeout = self .params .get ("sse_read_timeout" , timedelta (seconds = 60 * 5 )),
404+ terminate_on_close = self .params .get ("terminate_on_close" , True )
405+ )
406+
407+ @property
408+ def name (self ) -> str :
409+ """A readable name for the server."""
410+ return self ._name
0 commit comments