Skip to content

Commit fe91d33

Browse files
omaryashraf5Omar Abdelwahabashwinb
authored
fix: Remove authorization from provider data (#4161)
# What does this PR do? - Remove backward compatibility for authorization in mcp_headers - Enforce authorization must use dedicated parameter - Add validation error if Authorization found in provider_data headers - Update test_mcp.py to use authorization parameter - Update test_mcp_json_schema.py to use authorization parameter - Update test_tools_with_schemas.py to use authorization parameter - Update documentation to show the change in the authorization approach Breaking Change: - Authorization can no longer be passed via mcp_headers in provider_data - Users must use the dedicated 'authorization' parameter instead - Clear error message guides users to the new approach" ## Test Plan CI --------- Co-authored-by: Omar Abdelwahab <omara@fb.com> Co-authored-by: Ashwin Bharambe <ashwin.bharambe@gmail.com>
1 parent 0128eff commit fe91d33

File tree

5 files changed

+61
-172
lines changed

5 files changed

+61
-172
lines changed

docs/docs/building_applications/tools.mdx

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -104,23 +104,19 @@ client.toolgroups.register(
104104
)
105105
```
106106

107-
Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide authorization headers to send to the MCP server using the "Provider Data" abstraction provided by Llama Stack. When making an agent call,
107+
Note that most of the more useful MCP servers need you to authenticate with them. Many of them use OAuth2.0 for authentication. You can provide the authorization token when creating the Agent:
108108

109109
```python
110110
agent = Agent(
111111
...,
112-
tools=["mcp::deepwiki"],
113-
extra_headers={
114-
"X-LlamaStack-Provider-Data": json.dumps(
115-
{
116-
"mcp_headers": {
117-
"http://mcp.deepwiki.com/sse": {
118-
"Authorization": "Bearer <your_access_token>",
119-
},
120-
},
121-
}
122-
),
123-
},
112+
tools=[
113+
{
114+
"type": "mcp",
115+
"server_url": "https://mcp.deepwiki.com/sse",
116+
"server_label": "mcp::deepwiki",
117+
"authorization": "<your_access_token>", # OAuth token (without "Bearer " prefix)
118+
}
119+
],
124120
)
125121
agent.create_turn(...)
126122
```

src/llama_stack/providers/remote/tool_runtime/model_context_protocol/model_context_protocol.py

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,10 @@ async def list_runtime_tools(
4848
if mcp_endpoint is None:
4949
raise ValueError("mcp_endpoint is required")
5050

51-
# Phase 1: Support both old header-based auth AND new authorization parameter
52-
# Get headers and auth from provider data (old approach)
53-
provider_headers, provider_auth = await self.get_headers_from_request(mcp_endpoint.uri)
51+
# Get other headers from provider data (but NOT authorization)
52+
provider_headers = await self.get_headers_from_request(mcp_endpoint.uri)
5453

55-
# New authorization parameter takes precedence over provider data
56-
final_authorization = authorization or provider_auth
57-
58-
return await list_mcp_tools(
59-
endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=final_authorization
60-
)
54+
return await list_mcp_tools(endpoint=mcp_endpoint.uri, headers=provider_headers, authorization=authorization)
6155

6256
async def invoke_tool(
6357
self, tool_name: str, kwargs: dict[str, Any], authorization: str | None = None
@@ -69,57 +63,53 @@ async def invoke_tool(
6963
if urlparse(endpoint).scheme not in ("http", "https"):
7064
raise ValueError(f"Endpoint {endpoint} is not a valid HTTP(S) URL")
7165

72-
# Phase 1: Support both old header-based auth AND new authorization parameter
73-
# Get headers and auth from provider data (old approach)
74-
provider_headers, provider_auth = await self.get_headers_from_request(endpoint)
75-
76-
# New authorization parameter takes precedence over provider data
77-
final_authorization = authorization or provider_auth
66+
# Get other headers from provider data (but NOT authorization)
67+
provider_headers = await self.get_headers_from_request(endpoint)
7868

7969
return await invoke_mcp_tool(
8070
endpoint=endpoint,
8171
tool_name=tool_name,
8272
kwargs=kwargs,
8373
headers=provider_headers,
84-
authorization=final_authorization,
74+
authorization=authorization,
8575
)
8676

87-
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> tuple[dict[str, str], str | None]:
77+
async def get_headers_from_request(self, mcp_endpoint_uri: str) -> dict[str, str]:
8878
"""
89-
Extract headers and authorization from request provider data (Phase 1 backward compatibility).
79+
Extract headers from request provider data, excluding authorization.
9080
91-
Phase 1: Temporarily allows Authorization to be passed via mcp_headers for backward compatibility.
92-
Phase 2: Will enforce that Authorization should use the dedicated authorization parameter instead.
81+
Authorization must be provided via the dedicated authorization parameter.
82+
If Authorization is found in mcp_headers, raise an error to guide users to the correct approach.
83+
84+
Args:
85+
mcp_endpoint_uri: The MCP endpoint URI to match against provider data
9386
9487
Returns:
95-
Tuple of (headers_dict, authorization_token)
96-
- headers_dict: All headers except Authorization
97-
- authorization_token: Token from Authorization header (with "Bearer " prefix removed), or None
88+
dict[str, str]: Headers dictionary (without Authorization)
89+
90+
Raises:
91+
ValueError: If Authorization header is found in mcp_headers
9892
"""
9993

10094
def canonicalize_uri(uri: str) -> str:
10195
return f"{urlparse(uri).netloc or ''}/{urlparse(uri).path or ''}"
10296

10397
headers = {}
104-
authorization = None
10598

10699
provider_data = self.get_request_provider_data()
107100
if provider_data and hasattr(provider_data, "mcp_headers") and provider_data.mcp_headers:
108101
for uri, values in provider_data.mcp_headers.items():
109102
if canonicalize_uri(uri) != canonicalize_uri(mcp_endpoint_uri):
110103
continue
111104

112-
# Phase 1: Extract Authorization from mcp_headers for backward compatibility
113-
# (Phase 2 will reject this and require the dedicated authorization parameter)
105+
# Reject Authorization in mcp_headers - must use authorization parameter
114106
for key in values.keys():
115107
if key.lower() == "authorization":
116-
# Extract authorization token and strip "Bearer " prefix if present
117-
auth_value = values[key]
118-
if auth_value.startswith("Bearer "):
119-
authorization = auth_value[7:] # Remove "Bearer " prefix
120-
else:
121-
authorization = auth_value
122-
else:
123-
headers[key] = values[key]
124-
125-
return headers, authorization
108+
raise ValueError(
109+
"Authorization cannot be provided via mcp_headers in provider_data. "
110+
"Please use the dedicated 'authorization' parameter instead. "
111+
"Example: tool_runtime.invoke_tool(..., authorization='your-token')"
112+
)
113+
headers[key] = values[key]
114+
115+
return headers

tests/integration/inference/test_tools_with_schemas.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
Tests that tools pass through correctly to various LLM providers.
1010
"""
1111

12-
import json
13-
1412
import pytest
1513

1614
from llama_stack.core.library_client import LlamaStackAsLibraryClient
@@ -193,22 +191,11 @@ def test_mcp_tools_in_inference(self, llama_stack_client, text_model_id, mcp_wit
193191
mcp_endpoint=dict(uri=uri),
194192
)
195193

196-
# Use old header-based approach for Phase 1 (backward compatibility)
197-
provider_data = {
198-
"mcp_headers": {
199-
uri: {
200-
"Authorization": f"Bearer {AUTH_TOKEN}",
201-
},
202-
},
203-
}
204-
auth_headers = {
205-
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
206-
}
207-
194+
# Use the dedicated authorization parameter
208195
# Get the tools from MCP
209196
tools_response = llama_stack_client.tool_runtime.list_tools(
210197
tool_group_id=test_toolgroup_id,
211-
extra_headers=auth_headers,
198+
authorization=AUTH_TOKEN,
212199
)
213200

214201
# Convert to OpenAI format for inference

tests/integration/tool_runtime/test_mcp.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
# This source code is licensed under the terms described in the LICENSE file in
55
# the root directory of this source tree.
66

7-
import json
8-
97
import pytest
108
from llama_stack_client.lib.agents.agent import Agent
119
from llama_stack_client.lib.agents.turn_events import StepCompleted, StepProgress, ToolCallIssuedDelta
@@ -37,32 +35,20 @@ def test_mcp_invocation(llama_stack_client, text_model_id, mcp_server):
3735
mcp_endpoint=dict(uri=uri),
3836
)
3937

40-
# Use old header-based approach for Phase 1 (backward compatibility)
41-
provider_data = {
42-
"mcp_headers": {
43-
uri: {
44-
"Authorization": f"Bearer {AUTH_TOKEN}",
45-
},
46-
},
47-
}
48-
auth_headers = {
49-
"X-LlamaStack-Provider-Data": json.dumps(provider_data),
50-
}
51-
52-
with pytest.raises(Exception, match="Unauthorized"):
53-
llama_stack_client.tools.list(toolgroup_id=test_toolgroup_id)
54-
55-
tools_list = llama_stack_client.tools.list(
56-
toolgroup_id=test_toolgroup_id,
57-
extra_headers=auth_headers, # Use old header-based approach
38+
# Use the dedicated authorization parameter (no more provider_data headers)
39+
# This tests direct tool_runtime.invoke_tool API calls
40+
tools_list = llama_stack_client.tool_runtime.list_tools(
41+
tool_group_id=test_toolgroup_id,
42+
authorization=AUTH_TOKEN, # Use dedicated authorization parameter
5843
)
5944
assert len(tools_list) == 2
6045
assert {t.name for t in tools_list} == {"greet_everyone", "get_boiling_point"}
6146

47+
# Invoke tool with authorization parameter
6248
response = llama_stack_client.tool_runtime.invoke_tool(
6349
tool_name="greet_everyone",
6450
kwargs=dict(url="https://www.google.com"),
65-
extra_headers=auth_headers, # Use old header-based approach
51+
authorization=AUTH_TOKEN, # Use dedicated authorization parameter
6652
)
6753
content = response.content
6854
assert len(content) == 1

0 commit comments

Comments
 (0)