33import base64
44import functools
55from abc import ABC , abstractmethod
6- from collections .abc import AsyncIterator , Awaitable , Iterator , Sequence
7- from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager , contextmanager
8- from contextvars import ContextVar
6+ from collections .abc import AsyncIterator , Awaitable , Sequence
7+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
98from dataclasses import dataclass
109from pathlib import Path
1110from types import TracebackType
@@ -61,22 +60,6 @@ class MCPServer(ABC):
6160 _exit_stack : AsyncExitStack
6261 sampling_model : models .Model | None = None
6362
64- def __post_init__ (self ):
65- self ._override_sampling_model : ContextVar [models .Model | None ] = ContextVar (
66- '_override_sampling_model' , default = None
67- )
68-
69- @contextmanager
70- def override_sampling_model (
71- self ,
72- model : models .Model ,
73- ) -> Iterator [None ]:
74- token = self ._override_sampling_model .set (model )
75- try :
76- yield
77- finally :
78- self ._override_sampling_model .reset (token )
79-
8063 @abstractmethod
8164 @asynccontextmanager
8265 async def client_streams (
@@ -201,8 +184,7 @@ async def _sampling_callback(
201184 self , context : RequestContext [ClientSession , Any ], params : mcp_types .CreateMessageRequestParams
202185 ) -> mcp_types .CreateMessageResult | mcp_types .ErrorData :
203186 """MCP sampling callback."""
204- sampling_model = self ._override_sampling_model .get () or self .sampling_model
205- if sampling_model is None :
187+ if self .sampling_model is None :
206188 raise ValueError ('Sampling model is not set' ) # pragma: no cover
207189
208190 pai_messages = _mcp .map_from_mcp_params (params )
@@ -214,15 +196,15 @@ async def _sampling_callback(
214196 if stop_sequences := params .stopSequences : # pragma: no branch
215197 model_settings ['stop_sequences' ] = stop_sequences
216198
217- model_response = await sampling_model .request (
199+ model_response = await self . sampling_model .request (
218200 pai_messages ,
219201 model_settings ,
220202 models .ModelRequestParameters (),
221203 )
222204 return mcp_types .CreateMessageResult (
223205 role = 'assistant' ,
224206 content = _mcp .map_from_model_response (model_response ),
225- model = sampling_model .model_name ,
207+ model = self . sampling_model .model_name ,
226208 )
227209
228210 def _map_tool_result_part (
0 commit comments