|
8 | 8 | import socket |
9 | 9 | import time |
10 | 10 | from collections.abc import Generator |
| 11 | +from typing import Any |
11 | 12 |
|
12 | 13 | import anyio |
13 | 14 | import httpx |
|
33 | 34 | StreamId, |
34 | 35 | ) |
35 | 36 | from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
| 37 | +from mcp.shared.context import RequestContext |
36 | 38 | from mcp.shared.exceptions import McpError |
37 | 39 | from mcp.shared.message import ( |
38 | 40 | ClientMessageMetadata, |
@@ -139,6 +141,11 @@ async def handle_list_tools() -> list[Tool]: |
139 | 141 | description="A long-running tool that sends periodic notifications", |
140 | 142 | inputSchema={"type": "object", "properties": {}}, |
141 | 143 | ), |
| 144 | + Tool( |
| 145 | + name="test_sampling_tool", |
| 146 | + description="A tool that triggers server-side sampling", |
| 147 | + inputSchema={"type": "object", "properties": {}}, |
| 148 | + ), |
142 | 149 | ] |
143 | 150 |
|
144 | 151 | @self.call_tool() |
@@ -174,6 +181,34 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]: |
174 | 181 |
|
175 | 182 | return [TextContent(type="text", text="Completed!")] |
176 | 183 |
|
| 184 | + elif name == "test_sampling_tool": |
| 185 | + # Test sampling by requesting the client to sample a message |
| 186 | + sampling_result = await ctx.session.create_message( |
| 187 | + messages=[ |
| 188 | + types.SamplingMessage( |
| 189 | + role="user", |
| 190 | + content=types.TextContent( |
| 191 | + type="text", text="Server needs client sampling" |
| 192 | + ), |
| 193 | + ) |
| 194 | + ], |
| 195 | + max_tokens=100, |
| 196 | + related_request_id=ctx.request_id, |
| 197 | + ) |
| 198 | + |
| 199 | + # Return the sampling result in the tool response |
| 200 | + response = ( |
| 201 | + sampling_result.content.text |
| 202 | + if sampling_result.content.type == "text" |
| 203 | + else None |
| 204 | + ) |
| 205 | + return [ |
| 206 | + TextContent( |
| 207 | + type="text", |
| 208 | + text=f"Response from sampling: {response}", |
| 209 | + ) |
| 210 | + ] |
| 211 | + |
177 | 212 | return [TextContent(type="text", text=f"Called {name}")] |
178 | 213 |
|
179 | 214 |
|
@@ -754,7 +789,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session) |
754 | 789 | """Test client tool invocation.""" |
755 | 790 | # First list tools |
756 | 791 | tools = await initialized_client_session.list_tools() |
757 | | - assert len(tools.tools) == 3 |
| 792 | + assert len(tools.tools) == 4 |
758 | 793 | assert tools.tools[0].name == "test_tool" |
759 | 794 |
|
760 | 795 | # Call the tool |
@@ -795,7 +830,7 @@ async def test_streamablehttp_client_session_persistence( |
795 | 830 |
|
796 | 831 | # Make multiple requests to verify session persistence |
797 | 832 | tools = await session.list_tools() |
798 | | - assert len(tools.tools) == 3 |
| 833 | + assert len(tools.tools) == 4 |
799 | 834 |
|
800 | 835 | # Read a resource |
801 | 836 | resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) |
@@ -826,7 +861,7 @@ async def test_streamablehttp_client_json_response( |
826 | 861 |
|
827 | 862 | # Check tool listing |
828 | 863 | tools = await session.list_tools() |
829 | | - assert len(tools.tools) == 3 |
| 864 | + assert len(tools.tools) == 4 |
830 | 865 |
|
831 | 866 | # Call a tool and verify JSON response handling |
832 | 867 | result = await session.call_tool("test_tool", {}) |
@@ -905,7 +940,7 @@ async def test_streamablehttp_client_session_termination( |
905 | 940 |
|
906 | 941 | # Make a request to confirm session is working |
907 | 942 | tools = await session.list_tools() |
908 | | - assert len(tools.tools) == 3 |
| 943 | + assert len(tools.tools) == 4 |
909 | 944 |
|
910 | 945 | headers = {} |
911 | 946 | if captured_session_id: |
@@ -1054,3 +1089,71 @@ async def run_tool(): |
1054 | 1089 | assert not any( |
1055 | 1090 | n in captured_notifications_pre for n in captured_notifications |
1056 | 1091 | ) |
| 1092 | + |
| 1093 | + |
| 1094 | +@pytest.mark.anyio |
| 1095 | +async def test_streamablehttp_server_sampling(basic_server, basic_server_url): |
| 1096 | + """Test server-initiated sampling request through streamable HTTP transport.""" |
| 1097 | + print("Testing server sampling...") |
| 1098 | + # Variable to track if sampling callback was invoked |
| 1099 | + sampling_callback_invoked = False |
| 1100 | + captured_message_params = None |
| 1101 | + |
| 1102 | + # Define sampling callback that returns a mock response |
| 1103 | + async def sampling_callback( |
| 1104 | + context: RequestContext[ClientSession, Any], |
| 1105 | + params: types.CreateMessageRequestParams, |
| 1106 | + ) -> types.CreateMessageResult: |
| 1107 | + nonlocal sampling_callback_invoked, captured_message_params |
| 1108 | + sampling_callback_invoked = True |
| 1109 | + captured_message_params = params |
| 1110 | + message_received = ( |
| 1111 | + params.messages[0].content.text |
| 1112 | + if params.messages[0].content.type == "text" |
| 1113 | + else None |
| 1114 | + ) |
| 1115 | + |
| 1116 | + return types.CreateMessageResult( |
| 1117 | + role="assistant", |
| 1118 | + content=types.TextContent( |
| 1119 | + type="text", |
| 1120 | + text=f"Received message from server: {message_received}", |
| 1121 | + ), |
| 1122 | + model="test-model", |
| 1123 | + stopReason="endTurn", |
| 1124 | + ) |
| 1125 | + |
| 1126 | + # Create client with sampling callback |
| 1127 | + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( |
| 1128 | + read_stream, |
| 1129 | + write_stream, |
| 1130 | + _, |
| 1131 | + ): |
| 1132 | + async with ClientSession( |
| 1133 | + read_stream, |
| 1134 | + write_stream, |
| 1135 | + sampling_callback=sampling_callback, |
| 1136 | + ) as session: |
| 1137 | + # Initialize the session |
| 1138 | + result = await session.initialize() |
| 1139 | + assert isinstance(result, InitializeResult) |
| 1140 | + |
| 1141 | + # Call the tool that triggers server-side sampling |
| 1142 | + tool_result = await session.call_tool("test_sampling_tool", {}) |
| 1143 | + |
| 1144 | + # Verify the tool result contains the expected content |
| 1145 | + assert len(tool_result.content) == 1 |
| 1146 | + assert tool_result.content[0].type == "text" |
| 1147 | + assert ( |
| 1148 | + "Response from sampling: Received message from server" |
| 1149 | + in tool_result.content[0].text |
| 1150 | + ) |
| 1151 | + |
| 1152 | + # Verify sampling callback was invoked |
| 1153 | + assert sampling_callback_invoked |
| 1154 | + assert captured_message_params is not None |
| 1155 | + assert len(captured_message_params.messages) == 1 |
| 1156 | + assert ( |
| 1157 | + captured_message_params.messages[0].content.text |
| 1158 | + == "Server needs client sampling" |
| 1159 | + ) |
0 commit comments