-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Added a mechanism to extract metadata from MCP tool call response #3339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4bd79c9
bb1c256
98412c0
5c3f58f
73ba0a1
0e5ae00
ad73591
dfe81b6
fd806f7
d99be83
f8feb5b
39d47b5
6eea048
f54b127
7935568
026b364
b8ce49c
13fb859
46a7b87
41b3aa4
8154885
0133802
4a58898
137126a
c5b3b57
a0a1294
0c8251d
55598da
9f0b28f
a97b711
8267b71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -273,12 +273,62 @@ async def direct_call_tool( | |
| ): | ||
| # The MCP SDK wraps primitives and generic types like list in a `result` key, but we want to use the raw value returned by the tool function. | ||
| # See https://github.com/modelcontextprotocol/python-sdk#structured-output | ||
| if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured: | ||
| return structured['result'] | ||
| return structured | ||
| return_value = ( | ||
| structured['result'] | ||
| if isinstance(structured, dict) and len(structured) == 1 and 'result' in structured | ||
| else structured | ||
| ) | ||
| return messages.ToolReturn(return_value=return_value, metadata=result.meta) if result.meta else return_value | ||
|
|
||
| parts_with_metadata = [await self._map_tool_result_part(part) for part in result.content] | ||
| parts_only = [part for part, _ in parts_with_metadata] | ||
| # any_part_has_metadata = any(metadata is not None for _, metadata in parts_with_metadata) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not using this anymore? |
||
| return_values: list[Any] = [] | ||
| user_contents: list[Any] = [] | ||
| parts_metadata: dict[int, dict[str, Any]] = {} | ||
| return_metadata: dict[str, Any] = {} | ||
| # if any_part_has_metadata: | ||
| for idx, (part, part_metadata) in enumerate(parts_with_metadata): | ||
| if part_metadata is not None: | ||
| parts_metadata[idx] = part_metadata | ||
| # TODO: Keep updated with the multimodal content parsing in _agent_graph.py | ||
| if isinstance(part, messages.BinaryContent): | ||
| identifier = part.identifier | ||
|
|
||
| return_values.append(f'See file {identifier}') | ||
| user_contents.append([f'This is file {identifier}:', part]) | ||
| else: | ||
| user_contents.append(part) | ||
|
|
||
| # The following branching cannot be tested until FastMCP is updated to version 2.13.1 | ||
| # such that the MCP server can generate ToolResult and result.meta can be specified. | ||
| # TODO: Add tests for the following branching once FastMCP is updated. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding a comment here so we don't lose this |
||
| if len(parts_metadata) > 0: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here too, I prefer |
||
| if result.meta is not None and len(result.meta) > 0: # pragma: no cover | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should not have any no-covers if we can help it! Edit: You already pointed out why you did that, never mind :)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also is this equivalent to |
||
| # Merge the tool result metadata and parts metadata into the return metadata | ||
| return_metadata = {'result': result.meta, 'content': parts_metadata} | ||
| else: | ||
| # Only parts metadata exists | ||
| if len(parts_metadata) == 1: | ||
| # If there is only one content metadata, unwrap it | ||
| return_metadata = parts_metadata[0] | ||
| else: | ||
| return_metadata = {'content': parts_metadata} # pragma: no cover | ||
| else: | ||
| if result.meta is not None and len(result.meta) > 0: # pragma: no cover | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be |
||
| return_metadata = result.meta | ||
| # TODO: What else should we cover here? | ||
|
|
||
| mapped = [await self._map_tool_result_part(part) for part in result.content] | ||
| return mapped[0] if len(mapped) == 1 else mapped | ||
| # Finally, construct and return the ToolReturn object | ||
| return ( | ||
| messages.ToolReturn( | ||
| return_value=return_values, | ||
| content=user_contents, | ||
| metadata=return_metadata, | ||
| ) | ||
| if len(return_metadata) > 0 | ||
| else (parts_only[0] if len(parts_only) == 1 else parts_only) | ||
| ) | ||
|
|
||
| async def call_tool( | ||
| self, | ||
|
|
@@ -394,35 +444,57 @@ async def _sampling_callback( | |
|
|
||
| async def _map_tool_result_part( | ||
| self, part: mcp_types.ContentBlock | ||
| ) -> str | messages.BinaryContent | dict[str, Any] | list[Any]: | ||
| ) -> tuple[str | messages.BinaryContent | dict[str, Any] | list[Any], dict[str, Any] | None]: | ||
| # See https://github.com/jlowin/fastmcp/blob/main/docs/servers/tools.mdx#return-values | ||
|
|
||
| metadata: dict[str, Any] | None = part.meta | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if isinstance(part, mcp_types.TextContent): | ||
| text = part.text | ||
| if text.startswith(('[', '{')): | ||
| try: | ||
| return pydantic_core.from_json(text) | ||
| return pydantic_core.from_json(text), metadata | ||
| except ValueError: | ||
| pass | ||
| return text | ||
| return text, metadata | ||
| elif isinstance(part, mcp_types.ImageContent): | ||
| return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) | ||
| elif isinstance(part, mcp_types.AudioContent): | ||
| return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType), metadata | ||
| elif isinstance(part, mcp_types.AudioContent): # pragma: no cover | ||
| # NOTE: The FastMCP server doesn't support audio content. | ||
| # See <https://github.com/modelcontextprotocol/python-sdk/issues/952> for more details. | ||
| return messages.BinaryContent( | ||
| data=base64.b64decode(part.data), media_type=part.mimeType | ||
| ) # pragma: no cover | ||
| return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType), metadata | ||
| elif isinstance(part, mcp_types.EmbeddedResource): | ||
| resource = part.resource | ||
| return self._get_content(resource) | ||
| return self._get_content(part.resource), metadata | ||
| # The following branching cannot be tested until FastMCP is updated to version 2.13.1 | ||
| # such that the MCP server can generate ToolResult and result.meta can be specified. | ||
| # TODO: Add tests for the following branching once FastMCP is updated. | ||
| elif isinstance(part, mcp_types.ResourceLink): | ||
| resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri) | ||
anirbanbasu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return ( | ||
| self._get_content(resource_result.contents[0]) | ||
| if len(resource_result.contents) == 1 | ||
| else [self._get_content(resource) for resource in resource_result.contents] | ||
| ) | ||
| # Check if metadata already exists. If so, merge it with nested the resource metadata. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we dedupe any of this with the above with some helper functions? |
||
| parts_metadata: dict[int, dict[str, Any]] = {} | ||
| nested_metadata: dict[str, Any] = {} | ||
| for idx, content in enumerate(resource_result.contents): | ||
| if content.meta is not None: # pragma: no cover | ||
| parts_metadata[idx] = content.meta | ||
| if len(parts_metadata) > 0: | ||
| if resource_result.meta is not None and len(resource_result.meta) > 0: # pragma: no cover | ||
| # Merge the tool result metadata and parts metadata into the return metadata | ||
| nested_metadata = {'result': resource_result.meta, 'content': parts_metadata} | ||
| else: | ||
| # Only parts metadata exists | ||
| if len(parts_metadata) == 1: # pragma: no cover | ||
| # If there is only one content metadata, unwrap it | ||
| nested_metadata = parts_metadata[0] | ||
| else: | ||
| nested_metadata = {'content': parts_metadata} # pragma: no cover | ||
| else: | ||
| if resource_result.meta is not None and len(resource_result.meta) > 0: # pragma: no cover | ||
| nested_metadata = resource_result.meta | ||
| # FIXME: Is this a correct assumption? If metadata was read from the part then that is the same as resource_result.meta | ||
| metadata = nested_metadata | ||
| if len(resource_result.contents) == 1: | ||
| return self._get_content(resource_result.contents[0]), metadata | ||
| else: # pragma: no cover | ||
| return [self._get_content(resource) for resource in resource_result.contents], metadata | ||
| else: | ||
| assert_never(part) | ||
|
|
||
|
|
@@ -895,6 +967,7 @@ def __eq__(self, value: object, /) -> bool: | |
| ToolResult = ( | ||
| str | ||
| | messages.BinaryContent | ||
| | messages.ToolReturn | ||
| | dict[str, Any] | ||
| | list[Any] | ||
| | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |
| from pydantic_ai.agent import Agent | ||
| from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior, UserError | ||
| from pydantic_ai.mcp import MCPServerStreamableHTTP, load_mcp_servers | ||
| from pydantic_ai.messages import ToolReturn | ||
| from pydantic_ai.models import Model | ||
| from pydantic_ai.models.test import TestModel | ||
| from pydantic_ai.tools import RunContext | ||
|
|
@@ -77,7 +78,7 @@ async def test_stdio_server(run_context: RunContext[int]): | |
| server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) | ||
| async with server: | ||
| tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] | ||
| assert len(tools) == snapshot(18) | ||
| assert len(tools) == snapshot(20) | ||
| assert tools[0].name == 'celsius_to_fahrenheit' | ||
| assert isinstance(tools[0].description, str) | ||
| assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') | ||
|
|
@@ -87,6 +88,39 @@ async def test_stdio_server(run_context: RunContext[int]): | |
| assert result == snapshot(32.0) | ||
|
|
||
|
|
||
| async def test_tool_response_metadata(run_context: RunContext[int]): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'll want tests of every combination that we've covered up above |
||
| server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) | ||
| async with server: | ||
| tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] | ||
| assert len(tools) == snapshot(20) | ||
| assert tools[4].name == 'get_collatz_conjecture' | ||
| assert isinstance(tools[4].description, str) | ||
| assert tools[4].description.startswith('Generate the Collatz conjecture sequence for a given number.') | ||
|
|
||
| result = await server.direct_call_tool('get_collatz_conjecture', {'n': 7}) | ||
| assert isinstance(result, ToolReturn) | ||
| assert isinstance(result.content, list) | ||
| assert result.content[0] == snapshot([7, 22, 11, 34, 17, 52, 26, 13, 40, 20, 10, 5, 16, 8, 4, 2, 1]) | ||
| assert result.metadata == snapshot({'pydantic_ai': {'tool': 'collatz_conjecture', 'n': 7, 'length': 17}}) | ||
|
|
||
|
|
||
| async def test_tool_structured_response_metadata(run_context: RunContext[int]): | ||
| server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) | ||
| async with server: | ||
| tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] | ||
| assert len(tools) == snapshot(20) | ||
| assert tools[5].name == 'get_structured_text_content_with_metadata' | ||
| assert isinstance(tools[5].description, str) | ||
| assert tools[5].description.startswith('Return structured dict with metadata.') | ||
|
|
||
| result = await server.direct_call_tool('get_structured_text_content_with_metadata', {}) | ||
| assert isinstance(result, dict) | ||
| assert 'result' in result | ||
| assert result['result'] == 'This is some text content.' | ||
| assert '_meta' in result | ||
| assert result['_meta'] == snapshot({'pydantic_ai': {'source': 'get_structured_text_content_with_metadata'}}) | ||
|
|
||
|
|
||
| async def test_reentrant_context_manager(): | ||
| server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) | ||
| async with server: | ||
|
|
@@ -138,7 +172,7 @@ async def test_stdio_server_with_cwd(run_context: RunContext[int]): | |
| server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) | ||
| async with server: | ||
| tools = await server.get_tools(run_context) | ||
| assert len(tools) == snapshot(18) | ||
| assert len(tools) == snapshot(20) | ||
|
|
||
|
|
||
| async def test_process_tool_call(run_context: RunContext[int]) -> int: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.