11import json
22import httpx
3- from typing import Dict , Optional , Any , List , Union , Callable , Awaitable , Iterable , Literal , Sequence
3+ from typing import Dict , Optional , Any , List , Union , Literal , Sequence
44from typing_extensions import Annotated , Doc
55
66from fastapi import FastAPI , Request , APIRouter , params
1919logger = logging .getLogger (__name__ )
2020
2121
22- class LowlevelMCPServer (Server ):
23- def call_tool (self ):
24- """
25- A near-direct copy of `mcp.server.lowlevel.server.Server.call_tool()`, except that it looks for
26- the original HTTP request info in the MCP message, and passes it to the tool call handler.
27- """
28-
29- def decorator (
30- func : Callable [
31- ...,
32- Awaitable [Iterable [types .TextContent | types .ImageContent | types .EmbeddedResource ]],
33- ],
34- ):
35- logger .debug ("Registering handler for CallToolRequest" )
36-
37- async def handler (req : types .CallToolRequest ):
38- try :
39- # HACK: Pull the original HTTP request info from the MCP message. It was injected in
40- # `FastApiSseTransport.handle_fastapi_post_message()`
41- if hasattr (req .params , "_http_request_info" ) and req .params ._http_request_info is not None :
42- http_request_info = HTTPRequestInfo .model_validate (req .params ._http_request_info )
43- results = await func (req .params .name , (req .params .arguments or {}), http_request_info )
44- else :
45- results = await func (req .params .name , (req .params .arguments or {}))
46- return types .ServerResult (types .CallToolResult (content = list (results ), isError = False ))
47- except Exception as e :
48- return types .ServerResult (
49- types .CallToolResult (
50- content = [types .TextContent (type = "text" , text = str (e ))],
51- isError = True ,
52- )
53- )
54-
55- self .request_handlers [types .CallToolRequest ] = handler
56- return func
57-
58- return decorator
59-
60-
6122class FastApiMCP :
6223 """
6324 Create an MCP server from a FastAPI app.
@@ -115,14 +76,14 @@ def __init__(
11576 Doc ("Configuration for MCP authentication" ),
11677 ] = None ,
11778 headers : Annotated [
118- Optional [ List [str ] ],
79+ List [str ],
11980 Doc (
12081 """
12182 List of HTTP header names to forward from the incoming MCP request into each tool invocation.
12283 Only headers in this allowlist will be forwarded. Defaults to ['authorization'].
12384 """
12485 ),
125- ] = None ,
86+ ] = [ "authorization" ] ,
12687 ):
12788 # Validate operation and tag filtering options
12889 if include_operations is not None and exclude_operations is not None :
@@ -157,7 +118,7 @@ def __init__(
157118 timeout = 10.0 ,
158119 )
159120
160- self ._forward_headers = {h .lower () for h in ( headers or [ "Authorization" ]) }
121+ self ._forward_headers = {h .lower () for h in headers }
161122
162123 self .setup_server ()
163124
@@ -179,16 +140,40 @@ def setup_server(self) -> None:
179140 # Filter tools based on operation IDs and tags
180141 self .tools = self ._filter_tools (all_tools , openapi_schema )
181142
182- mcp_server : LowlevelMCPServer = LowlevelMCPServer (self .name , self .description )
143+ mcp_server : Server = Server (self .name , self .description )
183144
184145 @mcp_server .list_tools ()
185146 async def handle_list_tools () -> List [types .Tool ]:
186147 return self .tools
187148
188149 @mcp_server .call_tool ()
189150 async def handle_call_tool (
190- name : str , arguments : Dict [str , Any ], http_request_info : Optional [ HTTPRequestInfo ] = None
151+ name : str , arguments : Dict [str , Any ]
191152 ) -> List [Union [types .TextContent , types .ImageContent , types .EmbeddedResource ]]:
153+ # Extract HTTP request info from MCP context
154+ http_request_info = None
155+ try :
156+ # Access the MCP server's request context to get the original HTTP Request
157+ request_context = mcp_server .request_context
158+
159+ if request_context and hasattr (request_context , "request" ):
160+ http_request = request_context .request
161+
162+ if http_request and hasattr (http_request , "method" ):
163+ http_request_info = HTTPRequestInfo (
164+ method = http_request .method ,
165+ path = http_request .url .path ,
166+ headers = dict (http_request .headers ),
167+ cookies = http_request .cookies ,
168+ query_params = dict (http_request .query_params ),
169+ body = None ,
170+ )
171+ logger .debug (
172+ f"Extracted HTTP request info from context: { http_request_info .method } { http_request_info .path } "
173+ )
174+ except (LookupError , AttributeError ) as e :
175+ logger .error (f"Could not extract HTTP request info from context: { e } " )
176+
192177 return await self ._execute_api_tool (
193178 client = self ._http_client ,
194179 tool_name = name ,
0 commit comments