|
4 | 4 | import time |
5 | 5 | from collections.abc import AsyncGenerator, Generator |
6 | 6 | from typing import Any |
| 7 | +from unittest.mock import Mock |
7 | 8 |
|
8 | 9 | import anyio |
9 | 10 | import httpx |
|
16 | 17 | from starlette.responses import Response |
17 | 18 | from starlette.routing import Mount, Route |
18 | 19 |
|
| 20 | +import mcp.client.sse |
19 | 21 | import mcp.types as types |
20 | 22 | from mcp.client.session import ClientSession |
21 | | -from mcp.client.sse import sse_client |
| 23 | +from mcp.client.sse import _extract_session_id_from_endpoint, sse_client |
22 | 24 | from mcp.server import Server |
23 | 25 | from mcp.server.sse import SseServerTransport |
24 | 26 | from mcp.server.transport_security import TransportSecuritySettings |
@@ -184,6 +186,57 @@ async def test_sse_client_basic_connection(server: None, server_url: str) -> Non |
184 | 186 | assert isinstance(ping_result, EmptyResult) |
185 | 187 |
|
186 | 188 |
|
| 189 | +@pytest.mark.anyio |
| 190 | +async def test_sse_client_on_session_created(server: None, server_url: str) -> None: |
| 191 | + captured_session_id: str | None = None |
| 192 | + |
| 193 | + def on_session_created(session_id: str) -> None: |
| 194 | + nonlocal captured_session_id |
| 195 | + captured_session_id = session_id |
| 196 | + |
| 197 | + async with sse_client(server_url + "/sse", on_session_created=on_session_created) as streams: |
| 198 | + async with ClientSession(*streams) as session: |
| 199 | + result = await session.initialize() |
| 200 | + assert isinstance(result, InitializeResult) |
| 201 | + |
| 202 | + assert captured_session_id is not None |
| 203 | + assert len(captured_session_id) > 0 |
| 204 | + |
| 205 | + |
| 206 | +@pytest.mark.parametrize( |
| 207 | + "endpoint_url,expected", |
| 208 | + [ |
| 209 | + ("/messages?sessionId=abc123", "abc123"), |
| 210 | + ("/messages?session_id=def456", "def456"), |
| 211 | + ("/messages?sessionId=abc&session_id=def", "abc"), |
| 212 | + ("/messages?other=value", None), |
| 213 | + ("/messages", None), |
| 214 | + ("", None), |
| 215 | + ], |
| 216 | +) |
| 217 | +def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | None) -> None: |
| 218 | + assert _extract_session_id_from_endpoint(endpoint_url) == expected |
| 219 | + |
| 220 | + |
| 221 | +@pytest.mark.anyio |
| 222 | +async def test_sse_client_on_session_created_not_called_when_no_session_id( |
| 223 | + server: None, server_url: str, monkeypatch: pytest.MonkeyPatch |
| 224 | +) -> None: |
| 225 | + callback_mock = Mock() |
| 226 | + |
| 227 | + def mock_extract(url: str) -> None: |
| 228 | + return None |
| 229 | + |
| 230 | + monkeypatch.setattr(mcp.client.sse, "_extract_session_id_from_endpoint", mock_extract) |
| 231 | + |
| 232 | + async with sse_client(server_url + "/sse", on_session_created=callback_mock) as streams: |
| 233 | + async with ClientSession(*streams) as session: |
| 234 | + result = await session.initialize() |
| 235 | + assert isinstance(result, InitializeResult) |
| 236 | + |
| 237 | + callback_mock.assert_not_called() |
| 238 | + |
| 239 | + |
187 | 240 | @pytest.fixture |
188 | 241 | async def initialized_sse_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: |
189 | 242 | async with sse_client(server_url + "/sse", sse_read_timeout=0.5) as streams: |
|
0 commit comments