@@ -38,7 +38,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
3838"""
3939
4040from enum import Enum
41- from typing import Any , TypeVar
41+ from typing import Any , TypeVar , overload
4242
4343import anyio
4444import anyio .lowlevel
@@ -233,6 +233,7 @@ async def send_resource_updated(self, uri: AnyUrl) -> None: # pragma: no cover
233233 )
234234 )
235235
236+ @overload
236237 async def create_message (
237238 self ,
238239 messages : list [types .SamplingMessage ],
@@ -244,10 +245,47 @@ async def create_message(
244245 stop_sequences : list [str ] | None = None ,
245246 metadata : dict [str , Any ] | None = None ,
246247 model_preferences : types .ModelPreferences | None = None ,
247- tools : list [ types . Tool ] | None = None ,
248+ tools : None = None ,
248249 tool_choice : types .ToolChoice | None = None ,
249250 related_request_id : types .RequestId | None = None ,
250251 ) -> types .CreateMessageResult :
252+ """Overload: Without tools, returns single content."""
253+ ...
254+
255+ @overload
256+ async def create_message (
257+ self ,
258+ messages : list [types .SamplingMessage ],
259+ * ,
260+ max_tokens : int ,
261+ system_prompt : str | None = None ,
262+ include_context : types .IncludeContext | None = None ,
263+ temperature : float | None = None ,
264+ stop_sequences : list [str ] | None = None ,
265+ metadata : dict [str , Any ] | None = None ,
266+ model_preferences : types .ModelPreferences | None = None ,
267+ tools : list [types .Tool ],
268+ tool_choice : types .ToolChoice | None = None ,
269+ related_request_id : types .RequestId | None = None ,
270+ ) -> types .CreateMessageResultWithTools :
271+ """Overload: With tools, returns array-capable content."""
272+ ...
273+
274+ async def create_message (
275+ self ,
276+ messages : list [types .SamplingMessage ],
277+ * ,
278+ max_tokens : int ,
279+ system_prompt : str | None = None ,
280+ include_context : types .IncludeContext | None = None ,
281+ temperature : float | None = None ,
282+ stop_sequences : list [str ] | None = None ,
283+ metadata : dict [str , Any ] | None = None ,
284+ model_preferences : types .ModelPreferences | None = None ,
285+ tools : list [types .Tool ] | None = None ,
286+ tool_choice : types .ToolChoice | None = None ,
287+ related_request_id : types .RequestId | None = None ,
288+ ) -> types .CreateMessageResult | types .CreateMessageResultWithTools :
251289 """Send a sampling/create_message request.
252290
253291 Args:
@@ -278,27 +316,35 @@ async def create_message(
278316 validate_sampling_tools (client_caps , tools , tool_choice )
279317 validate_tool_use_result_messages (messages )
280318
319+ request = types .ServerRequest (
320+ types .CreateMessageRequest (
321+ params = types .CreateMessageRequestParams (
322+ messages = messages ,
323+ systemPrompt = system_prompt ,
324+ includeContext = include_context ,
325+ temperature = temperature ,
326+ maxTokens = max_tokens ,
327+ stopSequences = stop_sequences ,
328+ metadata = metadata ,
329+ modelPreferences = model_preferences ,
330+ tools = tools ,
331+ toolChoice = tool_choice ,
332+ ),
333+ )
334+ )
335+ metadata_obj = ServerMessageMetadata (related_request_id = related_request_id )
336+
337+ # Use different result types based on whether tools are provided
338+ if tools is not None :
339+ return await self .send_request (
340+ request = request ,
341+ result_type = types .CreateMessageResultWithTools ,
342+ metadata = metadata_obj ,
343+ )
281344 return await self .send_request (
282- request = types .ServerRequest (
283- types .CreateMessageRequest (
284- params = types .CreateMessageRequestParams (
285- messages = messages ,
286- systemPrompt = system_prompt ,
287- includeContext = include_context ,
288- temperature = temperature ,
289- maxTokens = max_tokens ,
290- stopSequences = stop_sequences ,
291- metadata = metadata ,
292- modelPreferences = model_preferences ,
293- tools = tools ,
294- toolChoice = tool_choice ,
295- ),
296- )
297- ),
345+ request = request ,
298346 result_type = types .CreateMessageResult ,
299- metadata = ServerMessageMetadata (
300- related_request_id = related_request_id ,
301- ),
347+ metadata = metadata_obj ,
302348 )
303349
304350 async def list_roots (self ) -> types .ListRootsResult :
0 commit comments