Skip to content

Commit fa851d9

Browse files
feat: backwards-compatible create_message overloads for SEP-1577 (#1713)
1 parent f82b0c9 commit fa851d9

File tree

9 files changed

+209
-32
lines changed

9 files changed

+209
-32
lines changed

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -948,8 +948,9 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
948948
max_tokens=100,
949949
)
950950

951-
if all(c.type == "text" for c in result.content_as_list):
952-
return "\n".join(c.text for c in result.content_as_list if c.type == "text")
951+
# Since we're not passing tools param, result.content is single content
952+
if result.content.type == "text":
953+
return result.content.text
953954
return str(result.content)
954955
```
955956

examples/servers/everything-server/mcp_everything_server/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,9 @@ async def test_sampling(prompt: str, ctx: Context[ServerSession, None]) -> str:
178178
max_tokens=100,
179179
)
180180

181-
if any(c.type == "text" for c in result.content_as_list):
182-
model_response = "\n".join(c.text for c in result.content_as_list if c.type == "text")
181+
# Since we're not passing tools param, result.content is single content
182+
if result.content.type == "text":
183+
model_response = result.content.text
183184
else:
184185
model_response = "No response"
185186

examples/snippets/servers/sampling.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ async def generate_poem(topic: str, ctx: Context[ServerSession, None]) -> str:
2020
max_tokens=100,
2121
)
2222

23-
if all(c.type == "text" for c in result.content_as_list):
24-
return "\n".join(c.text for c in result.content_as_list if c.type == "text")
23+
# Since we're not passing tools param, result.content is single content
24+
if result.content.type == "text":
25+
return result.content.text
2526
return str(result.content)

src/mcp/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
CompleteRequest,
1414
CreateMessageRequest,
1515
CreateMessageResult,
16+
CreateMessageResultWithTools,
1617
ErrorData,
1718
GetPromptRequest,
1819
GetPromptResult,
@@ -42,6 +43,7 @@
4243
ResourceUpdatedNotification,
4344
RootsCapability,
4445
SamplingCapability,
46+
SamplingContent,
4547
SamplingContextCapability,
4648
SamplingMessage,
4749
SamplingMessageContentBlock,
@@ -75,6 +77,7 @@
7577
"CompleteRequest",
7678
"CreateMessageRequest",
7779
"CreateMessageResult",
80+
"CreateMessageResultWithTools",
7881
"ErrorData",
7982
"GetPromptRequest",
8083
"GetPromptResult",
@@ -105,6 +108,7 @@
105108
"ResourceUpdatedNotification",
106109
"RootsCapability",
107110
"SamplingCapability",
111+
"SamplingContent",
108112
"SamplingContextCapability",
109113
"SamplingMessage",
110114
"SamplingMessageContentBlock",

src/mcp/server/session.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
3838
"""
3939

4040
from enum import Enum
41-
from typing import Any, TypeVar
41+
from typing import Any, TypeVar, overload
4242

4343
import anyio
4444
import anyio.lowlevel
@@ -233,6 +233,7 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: # pragma: no cover
233233
)
234234
)
235235

236+
@overload
236237
async def create_message(
237238
self,
238239
messages: list[types.SamplingMessage],
@@ -244,10 +245,47 @@ async def create_message(
244245
stop_sequences: list[str] | None = None,
245246
metadata: dict[str, Any] | None = None,
246247
model_preferences: types.ModelPreferences | None = None,
247-
tools: list[types.Tool] | None = None,
248+
tools: None = None,
248249
tool_choice: types.ToolChoice | None = None,
249250
related_request_id: types.RequestId | None = None,
250251
) -> types.CreateMessageResult:
252+
"""Overload: Without tools, returns single content."""
253+
...
254+
255+
@overload
256+
async def create_message(
257+
self,
258+
messages: list[types.SamplingMessage],
259+
*,
260+
max_tokens: int,
261+
system_prompt: str | None = None,
262+
include_context: types.IncludeContext | None = None,
263+
temperature: float | None = None,
264+
stop_sequences: list[str] | None = None,
265+
metadata: dict[str, Any] | None = None,
266+
model_preferences: types.ModelPreferences | None = None,
267+
tools: list[types.Tool],
268+
tool_choice: types.ToolChoice | None = None,
269+
related_request_id: types.RequestId | None = None,
270+
) -> types.CreateMessageResultWithTools:
271+
"""Overload: With tools, returns array-capable content."""
272+
...
273+
274+
async def create_message(
275+
self,
276+
messages: list[types.SamplingMessage],
277+
*,
278+
max_tokens: int,
279+
system_prompt: str | None = None,
280+
include_context: types.IncludeContext | None = None,
281+
temperature: float | None = None,
282+
stop_sequences: list[str] | None = None,
283+
metadata: dict[str, Any] | None = None,
284+
model_preferences: types.ModelPreferences | None = None,
285+
tools: list[types.Tool] | None = None,
286+
tool_choice: types.ToolChoice | None = None,
287+
related_request_id: types.RequestId | None = None,
288+
) -> types.CreateMessageResult | types.CreateMessageResultWithTools:
251289
"""Send a sampling/create_message request.
252290
253291
Args:
@@ -278,27 +316,35 @@ async def create_message(
278316
validate_sampling_tools(client_caps, tools, tool_choice)
279317
validate_tool_use_result_messages(messages)
280318

319+
request = types.ServerRequest(
320+
types.CreateMessageRequest(
321+
params=types.CreateMessageRequestParams(
322+
messages=messages,
323+
systemPrompt=system_prompt,
324+
includeContext=include_context,
325+
temperature=temperature,
326+
maxTokens=max_tokens,
327+
stopSequences=stop_sequences,
328+
metadata=metadata,
329+
modelPreferences=model_preferences,
330+
tools=tools,
331+
toolChoice=tool_choice,
332+
),
333+
)
334+
)
335+
metadata_obj = ServerMessageMetadata(related_request_id=related_request_id)
336+
337+
# Use different result types based on whether tools are provided
338+
if tools is not None:
339+
return await self.send_request(
340+
request=request,
341+
result_type=types.CreateMessageResultWithTools,
342+
metadata=metadata_obj,
343+
)
281344
return await self.send_request(
282-
request=types.ServerRequest(
283-
types.CreateMessageRequest(
284-
params=types.CreateMessageRequestParams(
285-
messages=messages,
286-
systemPrompt=system_prompt,
287-
includeContext=include_context,
288-
temperature=temperature,
289-
maxTokens=max_tokens,
290-
stopSequences=stop_sequences,
291-
metadata=metadata,
292-
modelPreferences=model_preferences,
293-
tools=tools,
294-
toolChoice=tool_choice,
295-
),
296-
)
297-
),
345+
request=request,
298346
result_type=types.CreateMessageResult,
299-
metadata=ServerMessageMetadata(
300-
related_request_id=related_request_id,
301-
),
347+
metadata=metadata_obj,
302348
)
303349

304350
async def list_roots(self) -> types.ListRootsResult:

src/mcp/types.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,10 @@ class ToolResultContent(BaseModel):
11461146
SamplingMessageContentBlock: TypeAlias = TextContent | ImageContent | AudioContent | ToolUseContent | ToolResultContent
11471147
"""Content block types allowed in sampling messages."""
11481148

1149+
SamplingContent: TypeAlias = TextContent | ImageContent | AudioContent
1150+
"""Basic content types for sampling responses (without tool use).
1151+
Used for backwards-compatible CreateMessageResult when tools are not used."""
1152+
11491153

11501154
class SamplingMessage(BaseModel):
11511155
"""Describes a message issued to or received from an LLM API."""
@@ -1543,7 +1547,27 @@ class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling
15431547

15441548

15451549
class CreateMessageResult(Result):
1546-
"""The client's response to a sampling/create_message request from the server."""
1550+
"""The client's response to a sampling/create_message request from the server.
1551+
1552+
This is the backwards-compatible version that returns single content (no arrays).
1553+
Used when the request does not include tools.
1554+
"""
1555+
1556+
role: Role
1557+
"""The role of the message sender (typically 'assistant' for LLM responses)."""
1558+
content: SamplingContent
1559+
"""Response content. Single content block (text, image, or audio)."""
1560+
model: str
1561+
"""The name of the model that generated the message."""
1562+
stopReason: StopReason | None = None
1563+
"""The reason why sampling stopped, if known."""
1564+
1565+
1566+
class CreateMessageResultWithTools(Result):
1567+
"""The client's response to a sampling/create_message request when tools were provided.
1568+
1569+
This version supports array content for tool use flows.
1570+
"""
15471571

15481572
role: Role
15491573
"""The role of the message sender (typically 'assistant' for LLM responses)."""

tests/client/test_sampling_callback.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
from mcp.types import (
99
CreateMessageRequestParams,
1010
CreateMessageResult,
11+
CreateMessageResultWithTools,
1112
SamplingMessage,
1213
TextContent,
14+
ToolUseContent,
1315
)
1416

1517

@@ -56,3 +58,79 @@ async def test_sampling_tool(message: str):
5658
assert result.isError is True
5759
assert isinstance(result.content[0], TextContent)
5860
assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"
61+
62+
63+
@pytest.mark.anyio
64+
async def test_create_message_backwards_compat_single_content():
65+
"""Test backwards compatibility: create_message without tools returns single content."""
66+
from mcp.server.fastmcp import FastMCP
67+
68+
server = FastMCP("test")
69+
70+
# Callback returns single content (text)
71+
callback_return = CreateMessageResult(
72+
role="assistant",
73+
content=TextContent(type="text", text="Hello from LLM"),
74+
model="test-model",
75+
stopReason="endTurn",
76+
)
77+
78+
async def sampling_callback(
79+
context: RequestContext[ClientSession, None],
80+
params: CreateMessageRequestParams,
81+
) -> CreateMessageResult:
82+
return callback_return
83+
84+
@server.tool("test_backwards_compat")
85+
async def test_tool(message: str):
86+
# Call create_message WITHOUT tools
87+
result = await server.get_context().session.create_message(
88+
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
89+
max_tokens=100,
90+
)
91+
# Backwards compat: result should be CreateMessageResult
92+
assert isinstance(result, CreateMessageResult)
93+
# Content should be single (not a list) - this is the key backwards compat check
94+
assert isinstance(result.content, TextContent)
95+
assert result.content.text == "Hello from LLM"
96+
# CreateMessageResult should NOT have content_as_list (that's on WithTools)
97+
assert not hasattr(result, "content_as_list") or not callable(getattr(result, "content_as_list", None))
98+
return True
99+
100+
async with create_session(server._mcp_server, sampling_callback=sampling_callback) as client_session:
101+
result = await client_session.call_tool("test_backwards_compat", {"message": "Test"})
102+
assert result.isError is False
103+
assert isinstance(result.content[0], TextContent)
104+
assert result.content[0].text == "true"
105+
106+
107+
@pytest.mark.anyio
108+
async def test_create_message_result_with_tools_type():
109+
"""Test that CreateMessageResultWithTools supports content_as_list."""
110+
# Test the type itself, not the overload (overload requires client capability setup)
111+
result = CreateMessageResultWithTools(
112+
role="assistant",
113+
content=ToolUseContent(type="tool_use", id="call_123", name="get_weather", input={"city": "SF"}),
114+
model="test-model",
115+
stopReason="toolUse",
116+
)
117+
118+
# CreateMessageResultWithTools should have content_as_list
119+
content_list = result.content_as_list
120+
assert len(content_list) == 1
121+
assert content_list[0].type == "tool_use"
122+
123+
# It should also work with array content
124+
result_array = CreateMessageResultWithTools(
125+
role="assistant",
126+
content=[
127+
TextContent(type="text", text="Let me check the weather"),
128+
ToolUseContent(type="tool_use", id="call_456", name="get_weather", input={"city": "NYC"}),
129+
],
130+
model="test-model",
131+
stopReason="toolUse",
132+
)
133+
content_list_array = result_array.content_as_list
134+
assert len(content_list_array) == 2
135+
assert content_list_array[0].type == "text"
136+
assert content_list_array[1].type == "tool_use"

tests/shared/test_streamable_http.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,9 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
242242
)
243243

244244
# Return the sampling result in the tool response
245-
if all(c.type == "text" for c in sampling_result.content_as_list):
246-
response = "\n".join(c.text for c in sampling_result.content_as_list if c.type == "text")
245+
# Since we're not passing tools param, result.content is single content
246+
if sampling_result.content.type == "text":
247+
response = sampling_result.content.text
247248
else:
248249
response = str(sampling_result.content)
249250
return [

tests/test_types.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
ClientRequest,
99
CreateMessageRequestParams,
1010
CreateMessageResult,
11+
CreateMessageResultWithTools,
1112
Implementation,
1213
InitializeRequest,
1314
InitializeRequestParams,
@@ -239,15 +240,16 @@ async def test_create_message_request_params_with_tools():
239240

240241
@pytest.mark.anyio
241242
async def test_create_message_result_with_tool_use():
242-
"""Test CreateMessageResult with tool use content for SEP-1577."""
243+
"""Test CreateMessageResultWithTools with tool use content for SEP-1577."""
243244
result_data = {
244245
"role": "assistant",
245246
"content": {"type": "tool_use", "name": "search", "id": "call_123", "input": {"query": "test"}},
246247
"model": "claude-3",
247248
"stopReason": "toolUse",
248249
}
249250

250-
result = CreateMessageResult.model_validate(result_data)
251+
# Tool use content uses CreateMessageResultWithTools
252+
result = CreateMessageResultWithTools.model_validate(result_data)
251253
assert result.role == "assistant"
252254
assert isinstance(result.content, ToolUseContent)
253255
assert result.stopReason == "toolUse"
@@ -259,6 +261,25 @@ async def test_create_message_result_with_tool_use():
259261
assert content_list[0] == result.content
260262

261263

264+
@pytest.mark.anyio
265+
async def test_create_message_result_basic():
266+
"""Test CreateMessageResult with basic text content (backwards compatible)."""
267+
result_data = {
268+
"role": "assistant",
269+
"content": {"type": "text", "text": "Hello!"},
270+
"model": "claude-3",
271+
"stopReason": "endTurn",
272+
}
273+
274+
# Basic content uses CreateMessageResult (single content, no arrays)
275+
result = CreateMessageResult.model_validate(result_data)
276+
assert result.role == "assistant"
277+
assert isinstance(result.content, TextContent)
278+
assert result.content.text == "Hello!"
279+
assert result.stopReason == "endTurn"
280+
assert result.model == "claude-3"
281+
282+
262283
@pytest.mark.anyio
263284
async def test_client_capabilities_with_sampling_tools():
264285
"""Test ClientCapabilities with nested sampling capabilities for SEP-1577."""

0 commit comments

Comments
 (0)