Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
'httpx',
'anyio',
'httpcore',
# Used by fastmcp via py-key-value-aio
'beartype',
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
'attrs',
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

from typing import Literal

from temporalio.workflow import ActivityConfig

from pydantic_ai import ToolsetTool
from pydantic_ai.tools import AgentDepsT, ToolDefinition
from pydantic_ai.toolsets.fastmcp import FastMCPToolset

from ._mcp import TemporalMCPToolset
from ._run_context import TemporalRunContext


class TemporalFastMCPToolset(TemporalMCPToolset[AgentDepsT]):
def __init__(
self,
toolset: FastMCPToolset[AgentDepsT],
*,
activity_name_prefix: str,
activity_config: ActivityConfig,
tool_activity_config: dict[str, ActivityConfig | Literal[False]],
deps_type: type[AgentDepsT],
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
):
super().__init__(
toolset,
activity_name_prefix=activity_name_prefix,
activity_config=activity_config,
tool_activity_config=tool_activity_config,
deps_type=deps_type,
run_context_type=run_context_type,
)

def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, FastMCPToolset)
return self.wrapped.tool_for_tool_def(tool_def)
147 changes: 147 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal

from pydantic import ConfigDict, with_config
from temporalio import activity, workflow
from temporalio.workflow import ActivityConfig
from typing_extensions import Self

from pydantic_ai import ToolsetTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
from pydantic_ai.toolsets import AbstractToolset

from ._run_context import TemporalRunContext
from ._toolset import (
CallToolParams,
CallToolResult,
TemporalWrapperToolset,
)


@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _GetToolsParams:
serialized_run_context: Any


class TemporalMCPToolset(TemporalWrapperToolset[AgentDepsT], ABC):
def __init__(
self,
toolset: AbstractToolset[AgentDepsT],
*,
activity_name_prefix: str,
activity_config: ActivityConfig,
tool_activity_config: dict[str, ActivityConfig | Literal[False]],
deps_type: type[AgentDepsT],
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
):
super().__init__(toolset)
self.activity_config = activity_config

self.tool_activity_config: dict[str, ActivityConfig] = {}
for tool_name, tool_config in tool_activity_config.items():
if tool_config is False:
raise UserError(
f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), '
'but MCP tools require the use of IO and so cannot be run outside of an activity.'
)
self.tool_activity_config[tool_name] = tool_config

self.run_context_type = run_context_type

async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
tools = await self.wrapped.get_tools(run_context)
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
return {name: tool.tool_def for name, tool in tools.items()}

# Set type hint explicitly so that Temporal can take care of serialization and deserialization
get_tools_activity.__annotations__['deps'] = deps_type

self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__get_tools')(
get_tools_activity
)

async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
assert isinstance(params.tool_def, ToolDefinition)
return await self._wrap_call_tool_result(
self.wrapped.call_tool(
params.name,
params.tool_args,
run_context,
self.tool_for_tool_def(params.tool_def),
)
)

# Set type hint explicitly so that Temporal can take care of serialization and deserialization
call_tool_activity.__annotations__['deps'] = deps_type

self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__call_tool')(
call_tool_activity
)

@abstractmethod
def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
raise NotImplementedError

@property
def temporal_activities(self) -> list[Callable[..., Any]]:
return [self.get_tools_activity, self.call_tool_activity]

async def __aenter__(self) -> Self:
# The wrapped MCPServer enters itself around listing and calling tools
# so we don't need to enter it here (nor could we because we're not inside a Temporal activity).
return self

async def __aexit__(self, *args: Any) -> bool | None:
return None

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
if not workflow.in_workflow():
return await super().get_tools(ctx)

serialized_run_context = self.run_context_type.serialize_run_context(ctx)
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.get_tools_activity,
args=[
_GetToolsParams(serialized_run_context=serialized_run_context),
ctx.deps,
],
**self.activity_config,
)
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> CallToolResult:
if not workflow.in_workflow():
return await super().call_tool(name, tool_args, ctx, tool)

tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
return self._unwrap_call_tool_result(
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.call_tool_activity,
args=[
CallToolParams(
name=name,
tool_args=tool_args,
serialized_run_context=serialized_run_context,
tool_def=tool.tool_def,
),
ctx.deps,
],
**tool_activity_config,
)
)
131 changes: 11 additions & 120 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp_server.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,18 @@
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Literal
from typing import Literal

from pydantic import ConfigDict, with_config
from temporalio import activity, workflow
from temporalio.workflow import ActivityConfig
from typing_extensions import Self

from pydantic_ai import ToolsetTool
from pydantic_ai.exceptions import UserError
from pydantic_ai.mcp import MCPServer
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
from pydantic_ai.tools import AgentDepsT, ToolDefinition

from ._mcp import TemporalMCPToolset
from ._run_context import TemporalRunContext
from ._toolset import (
CallToolParams,
CallToolResult,
TemporalWrapperToolset,
)


@dataclass
@with_config(ConfigDict(arbitrary_types_allowed=True))
class _GetToolsParams:
serialized_run_context: Any


class TemporalMCPServer(TemporalWrapperToolset[AgentDepsT]):
class TemporalMCPServer(TemporalMCPToolset[AgentDepsT]):
def __init__(
self,
server: MCPServer,
Expand All @@ -39,108 +23,15 @@ def __init__(
deps_type: type[AgentDepsT],
run_context_type: type[TemporalRunContext[AgentDepsT]] = TemporalRunContext[AgentDepsT],
):
super().__init__(server)
self.activity_config = activity_config

self.tool_activity_config: dict[str, ActivityConfig] = {}
for tool_name, tool_config in tool_activity_config.items():
if tool_config is False:
raise UserError(
f'Temporal activity config for MCP tool {tool_name!r} has been explicitly set to `False` (activity disabled), '
'but MCP tools require the use of IO and so cannot be run outside of an activity.'
)
self.tool_activity_config[tool_name] = tool_config

self.run_context_type = run_context_type

async def get_tools_activity(params: _GetToolsParams, deps: AgentDepsT) -> dict[str, ToolDefinition]:
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
tools = await self.wrapped.get_tools(run_context)
# ToolsetTool is not serializable as it holds a SchemaValidator (which is also the same for every MCP tool so unnecessary to pass along the wire every time),
# so we just return the ToolDefinitions and wrap them in ToolsetTool outside of the activity.
return {name: tool.tool_def for name, tool in tools.items()}

# Set type hint explicitly so that Temporal can take care of serialization and deserialization
get_tools_activity.__annotations__['deps'] = deps_type

self.get_tools_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__get_tools')(
get_tools_activity
)

async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallToolResult:
run_context = self.run_context_type.deserialize_run_context(params.serialized_run_context, deps=deps)
assert isinstance(params.tool_def, ToolDefinition)
return await self._wrap_call_tool_result(
self.wrapped.call_tool(
params.name,
params.tool_args,
run_context,
self.tool_for_tool_def(params.tool_def),
)
)

# Set type hint explicitly so that Temporal can take care of serialization and deserialization
call_tool_activity.__annotations__['deps'] = deps_type

self.call_tool_activity = activity.defn(name=f'{activity_name_prefix}__mcp_server__{self.id}__call_tool')(
call_tool_activity
super().__init__(
server,
activity_name_prefix=activity_name_prefix,
activity_config=activity_config,
tool_activity_config=tool_activity_config,
deps_type=deps_type,
run_context_type=run_context_type,
)

def tool_for_tool_def(self, tool_def: ToolDefinition) -> ToolsetTool[AgentDepsT]:
assert isinstance(self.wrapped, MCPServer)
return self.wrapped.tool_for_tool_def(tool_def)

@property
def temporal_activities(self) -> list[Callable[..., Any]]:
return [self.get_tools_activity, self.call_tool_activity]

async def __aenter__(self) -> Self:
# The wrapped MCPServer enters itself around listing and calling tools
# so we don't need to enter it here (nor could we because we're not inside a Temporal activity).
return self

async def __aexit__(self, *args: Any) -> bool | None:
return None

async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
if not workflow.in_workflow():
return await super().get_tools(ctx)

serialized_run_context = self.run_context_type.serialize_run_context(ctx)
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.get_tools_activity,
args=[
_GetToolsParams(serialized_run_context=serialized_run_context),
ctx.deps,
],
**self.activity_config,
)
return {name: self.tool_for_tool_def(tool_def) for name, tool_def in tool_defs.items()}

async def call_tool(
self,
name: str,
tool_args: dict[str, Any],
ctx: RunContext[AgentDepsT],
tool: ToolsetTool[AgentDepsT],
) -> CallToolResult:
if not workflow.in_workflow():
return await super().call_tool(name, tool_args, ctx, tool)

tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
return self._unwrap_call_tool_result(
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
activity=self.call_tool_activity,
args=[
CallToolParams(
name=name,
tool_args=tool_args,
serialized_run_context=serialized_run_context,
tool_def=tool.tool_def,
),
ctx.deps,
],
**tool_activity_config,
)
)
17 changes: 17 additions & 0 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,21 @@ def temporalize_toolset(
run_context_type=run_context_type,
)

try:
from pydantic_ai.toolsets.fastmcp import FastMCPToolset

from ._fastmcp_toolset import TemporalFastMCPToolset
except ImportError:
pass
else:
if isinstance(toolset, FastMCPToolset):
return TemporalFastMCPToolset(
toolset,
activity_name_prefix=activity_name_prefix,
activity_config=activity_config,
tool_activity_config=tool_activity_config,
deps_type=deps_type,
run_context_type=run_context_type,
)

return toolset
Loading