Skip to content

Commit 2cd178a

Browse files
Add on_session_created callback option (#1710)
1 parent c92bb2f commit 2cd178a

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

src/mcp/client/sse.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
2+
from collections.abc import Callable
23
from contextlib import asynccontextmanager
34
from typing import Any
4-
from urllib.parse import urljoin, urlparse
5+
from urllib.parse import parse_qs, urljoin, urlparse
56

67
import anyio
78
import httpx
@@ -21,6 +22,11 @@ def remove_request_params(url: str) -> str:
2122
return urljoin(url, urlparse(url).path)
2223

2324

25+
def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
26+
query_params = parse_qs(urlparse(endpoint_url).query)
27+
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]
28+
29+
2430
@asynccontextmanager
2531
async def sse_client(
2632
url: str,
@@ -29,6 +35,7 @@ async def sse_client(
2935
sse_read_timeout: float = 60 * 5,
3036
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
3137
auth: httpx.Auth | None = None,
38+
on_session_created: Callable[[str], None] | None = None,
3239
):
3340
"""
3441
Client transport for SSE.
@@ -42,6 +49,7 @@ async def sse_client(
4249
timeout: HTTP timeout for regular operations.
4350
sse_read_timeout: Timeout for SSE read operations.
4451
auth: Optional HTTPX authentication handler.
52+
on_session_created: Optional callback invoked with the session ID when received.
4553
"""
4654
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
4755
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
@@ -89,6 +97,11 @@ async def sse_reader(
8997
logger.error(error_msg) # pragma: no cover
9098
raise ValueError(error_msg) # pragma: no cover
9199

100+
if on_session_created:
101+
session_id = _extract_session_id_from_endpoint(endpoint_url)
102+
if session_id:
103+
on_session_created(session_id)
104+
92105
task_status.started(endpoint_url)
93106

94107
case "message":

tests/shared/test_sse.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
from collections.abc import AsyncGenerator, Generator
66
from typing import Any
7+
from unittest.mock import Mock
78

89
import anyio
910
import httpx
@@ -16,9 +17,10 @@
1617
from starlette.responses import Response
1718
from starlette.routing import Mount, Route
1819

20+
import mcp.client.sse
1921
import mcp.types as types
2022
from mcp.client.session import ClientSession
21-
from mcp.client.sse import sse_client
23+
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
2224
from mcp.server import Server
2325
from mcp.server.sse import SseServerTransport
2426
from mcp.server.transport_security import TransportSecuritySettings
@@ -184,6 +186,57 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non
184186
assert isinstance(ping_result, EmptyResult)
185187

186188

189+
@pytest.mark.anyio
190+
async def test_sse_client_on_session_created(server: None, server_url: str) -> None:
191+
captured_session_id: str | None = None
192+
193+
def on_session_created(session_id: str) -> None:
194+
nonlocal captured_session_id
195+
captured_session_id = session_id
196+
197+
async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams:
198+
async with ClientSession(*streams) as session:
199+
result = await session.initialize()
200+
assert isinstance(result, InitializeResult)
201+
202+
assert captured_session_id is not None
203+
assert len(captured_session_id) > 0
204+
205+
206+
@pytest.mark.parametrize(
207+
"endpoint_url,expected",
208+
[
209+
("/messages?sessionId=abc123", "abc123"),
210+
("/messages?session_id=def456", "def456"),
211+
("/messages?sessionId=abc&session_id=def", "abc"),
212+
("/messages?other=value", None),
213+
("/messages", None),
214+
("", None),
215+
],
216+
)
217+
def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None:
218+
assert _extract_session_id_from_endpoint(endpoint_url) == expected
219+
220+
221+
@pytest.mark.anyio
222+
async def test_sse_client_on_session_created_not_called_when_no_session_id(
223+
server: None, server_url: str, monkeypatch: pytest.MonkeyPatch
224+
) -> None:
225+
callback_mock = Mock()
226+
227+
def mock_extract(url: str) -> None:
228+
return None
229+
230+
monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract)
231+
232+
async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams:
233+
async with ClientSession(*streams) as session:
234+
result = await session.initialize()
235+
assert isinstance(result, InitializeResult)
236+
237+
callback_mock.assert_not_called()
238+
239+
187240
@pytest.fixture
188241
async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]:
189242
async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams:

0 commit comments

Comments
 (0)