Skip to content

Commit d10df9a

Browse files
authored
Bump temporalio and use SimplePlugin (#3214)
1 parent e72452f commit d10df9a

File tree

9 files changed

+109
-162
lines changed

9 files changed

+109
-162
lines changed
Lines changed: 58 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,14 @@
11
from __future__ import annotations
22

33
import warnings
4-
from collections.abc import AsyncIterator, Callable, Sequence
5-
from contextlib import AbstractAsyncContextManager
64
from dataclasses import replace
75
from typing import Any
86

97
from pydantic.errors import PydanticUserError
10-
from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory
118
from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter
129
from temporalio.converter import DataConverter, DefaultPayloadConverter
13-
from temporalio.service import ConnectConfig, ServiceClient
14-
from temporalio.worker import (
15-
Plugin as WorkerPlugin,
16-
Replayer,
17-
ReplayerConfig,
18-
Worker,
19-
WorkerConfig,
20-
WorkflowReplayResult,
21-
)
10+
from temporalio.plugin import SimplePlugin
11+
from temporalio.worker import WorkflowRunner
2212
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
2313

2414
from ...exceptions import UserError
@@ -48,104 +38,64 @@
4838
pass
4939

5040

51-
class PydanticAIPlugin(ClientPlugin, WorkerPlugin):
41+
def _data_converter(converter: DataConverter | None) -> DataConverter:
42+
if converter and converter.payload_converter_class not in (
43+
DefaultPayloadConverter,
44+
PydanticPayloadConverter,
45+
):
46+
warnings.warn( # pragma: no cover
47+
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
48+
)
49+
50+
return pydantic_data_converter
51+
52+
53+
def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
54+
if not runner:
55+
raise ValueError('No WorkflowRunner provided to the Pydantic AI plugin.') # pragma: no cover
56+
57+
if not isinstance(runner, SandboxedWorkflowRunner):
58+
return runner # pragma: no cover
59+
60+
return replace(
61+
runner,
62+
restrictions=runner.restrictions.with_passthrough_modules(
63+
'pydantic_ai',
64+
'pydantic',
65+
'pydantic_core',
66+
'logfire',
67+
'rich',
68+
'httpx',
69+
'anyio',
70+
'httpcore',
71+
# Used by fastmcp via py-key-value-aio
72+
'beartype',
73+
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
74+
'attrs',
75+
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
76+
'numpy',
77+
'pandas',
78+
),
79+
)
80+
81+
82+
class PydanticAIPlugin(SimplePlugin):
5283
"""Temporal client and worker plugin for Pydantic AI."""
5384

54-
def init_client_plugin(self, next: ClientPlugin) -> None:
55-
self.next_client_plugin = next
56-
57-
def init_worker_plugin(self, next: WorkerPlugin) -> None:
58-
self.next_worker_plugin = next
59-
60-
def configure_client(self, config: ClientConfig) -> ClientConfig:
61-
config['data_converter'] = self._get_new_data_converter(config.get('data_converter'))
62-
return self.next_client_plugin.configure_client(config)
63-
64-
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
65-
runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType]
66-
if isinstance(runner, SandboxedWorkflowRunner): # pragma: no branch
67-
config['workflow_runner'] = replace(
68-
runner,
69-
restrictions=runner.restrictions.with_passthrough_modules(
70-
'pydantic_ai',
71-
'pydantic',
72-
'pydantic_core',
73-
'logfire',
74-
'rich',
75-
'httpx',
76-
'anyio',
77-
'httpcore',
78-
# Used by fastmcp via py-key-value-aio
79-
'beartype',
80-
# Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize
81-
'attrs',
82-
# Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize
83-
'numpy',
84-
'pandas',
85-
),
86-
)
87-
88-
config['workflow_failure_exception_types'] = [
89-
*config.get('workflow_failure_exception_types', []), # pyright: ignore[reportUnknownMemberType]
90-
UserError,
91-
PydanticUserError,
92-
]
93-
94-
return self.next_worker_plugin.configure_worker(config)
95-
96-
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
97-
return await self.next_client_plugin.connect_service_client(config)
98-
99-
async def run_worker(self, worker: Worker) -> None:
100-
await self.next_worker_plugin.run_worker(worker)
101-
102-
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
103-
config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType]
104-
return self.next_worker_plugin.configure_replayer(config)
105-
106-
def run_replayer(
107-
self,
108-
replayer: Replayer,
109-
histories: AsyncIterator[WorkflowHistory],
110-
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
111-
return self.next_worker_plugin.run_replayer(replayer, histories)
112-
113-
def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter:
114-
if converter and converter.payload_converter_class not in (
115-
DefaultPayloadConverter,
116-
PydanticPayloadConverter,
117-
):
118-
warnings.warn( # pragma: no cover
119-
'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.'
120-
)
121-
122-
return pydantic_data_converter
123-
124-
125-
class AgentPlugin(WorkerPlugin):
126-
"""Temporal worker plugin for a specific Pydantic AI agent."""
127-
128-
def __init__(self, agent: TemporalAgent[Any, Any]):
129-
self.agent = agent
130-
131-
def init_worker_plugin(self, next: WorkerPlugin) -> None:
132-
self.next_worker_plugin = next
85+
def __init__(self):
86+
super().__init__( # type: ignore[reportUnknownMemberType]
87+
name='PydanticAIPlugin',
88+
data_converter=_data_converter,
89+
workflow_runner=_workflow_runner,
90+
workflow_failure_exception_types=[UserError, PydanticUserError],
91+
)
13392

134-
def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
135-
activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType]
136-
# Activities are checked for name conflicts by Temporal.
137-
config['activities'] = [*activities, *self.agent.temporal_activities]
138-
return self.next_worker_plugin.configure_worker(config)
13993

140-
async def run_worker(self, worker: Worker) -> None:
141-
await self.next_worker_plugin.run_worker(worker)
142-
143-
def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover
144-
return self.next_worker_plugin.configure_replayer(config)
94+
class AgentPlugin(SimplePlugin):
95+
"""Temporal worker plugin for a specific Pydantic AI agent."""
14596

146-
def run_replayer(
147-
self,
148-
replayer: Replayer,
149-
histories: AsyncIterator[WorkflowHistory],
150-
) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover
151-
return self.next_worker_plugin.run_replayer(replayer, histories)
97+
def __init__(self, agent: TemporalAgent[Any, Any]):
98+
super().__init__( # type: ignore[reportUnknownMemberType]
99+
name='AgentPlugin',
100+
activities=agent.temporal_activities,
101+
)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ async def _call_event_stream_handler_activity(
219219
) -> None:
220220
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
221221
async for event in stream:
222-
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
222+
await workflow.execute_activity(
223223
activity=self.event_stream_handler_activity,
224224
args=[
225225
_EventStreamHandlerParams(

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_function_toolset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ async def call_tool(
8181
tool_activity_config = self.activity_config | tool_activity_config
8282
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
8383
return self._unwrap_call_tool_result(
84-
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
84+
await workflow.execute_activity(
8585
activity=self.call_tool_activity,
8686
args=[
8787
CallToolParams(

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_logfire.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
3+
from collections.abc import Awaitable, Callable
44
from typing import TYPE_CHECKING
55

6-
from temporalio.client import ClientConfig, Plugin as ClientPlugin
6+
from temporalio.plugin import SimplePlugin
77
from temporalio.runtime import OpenTelemetryConfig, Runtime, TelemetryConfig
88
from temporalio.service import ConnectConfig, ServiceClient
99

@@ -19,12 +19,14 @@ def _default_setup_logfire() -> Logfire:
1919
return instance
2020

2121

22-
class LogfirePlugin(ClientPlugin):
22+
class LogfirePlugin(SimplePlugin):
2323
"""Temporal client plugin for Logfire."""
2424

2525
def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire, *, metrics: bool = True):
2626
try:
2727
import logfire # noqa: F401 # pyright: ignore[reportUnusedImport]
28+
from opentelemetry.trace import get_tracer
29+
from temporalio.contrib.opentelemetry import TracingInterceptor
2830
except ImportError as _import_error:
2931
raise ImportError(
3032
'Please install the `logfire` package to use the Logfire plugin, '
@@ -34,18 +36,14 @@ def __init__(self, setup_logfire: Callable[[], Logfire] = _default_setup_logfire
3436
self.setup_logfire = setup_logfire
3537
self.metrics = metrics
3638

37-
def init_client_plugin(self, next: ClientPlugin) -> None:
38-
self.next_client_plugin = next
39+
super().__init__( # type: ignore[reportUnknownMemberType]
40+
name='LogfirePlugin',
41+
client_interceptors=[TracingInterceptor(get_tracer('temporalio'))],
42+
)
3943

40-
def configure_client(self, config: ClientConfig) -> ClientConfig:
41-
from opentelemetry.trace import get_tracer
42-
from temporalio.contrib.opentelemetry import TracingInterceptor
43-
44-
interceptors = config.get('interceptors', [])
45-
config['interceptors'] = [*interceptors, TracingInterceptor(get_tracer('temporalio'))]
46-
return self.next_client_plugin.configure_client(config)
47-
48-
async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
44+
async def connect_service_client(
45+
self, config: ConnectConfig, next: Callable[[ConnectConfig], Awaitable[ServiceClient]]
46+
) -> ServiceClient:
4947
logfire = self.setup_logfire()
5048

5149
if self.metrics:
@@ -60,4 +58,4 @@ async def connect_service_client(self, config: ConnectConfig) -> ServiceClient:
6058
telemetry=TelemetryConfig(metrics=OpenTelemetryConfig(url=metrics_url, headers=headers))
6159
)
6260

63-
return await self.next_client_plugin.connect_service_client(config)
61+
return await next(config)

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_mcp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
108108
return await super().get_tools(ctx)
109109

110110
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
111-
tool_defs = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
111+
tool_defs = await workflow.execute_activity(
112112
activity=self.get_tools_activity,
113113
args=[
114114
_GetToolsParams(serialized_run_context=serialized_run_context),
@@ -131,7 +131,7 @@ async def call_tool(
131131
tool_activity_config = self.activity_config | self.tool_activity_config.get(name, {})
132132
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
133133
return self._unwrap_call_tool_result(
134-
await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
134+
await workflow.execute_activity(
135135
activity=self.call_tool_activity,
136136
args=[
137137
CallToolParams(

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ async def request(
130130

131131
self._validate_model_request_parameters(model_request_parameters)
132132

133-
return await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
133+
return await workflow.execute_activity(
134134
activity=self.request_activity,
135135
arg=_RequestParams(
136136
messages=messages,
@@ -168,7 +168,7 @@ async def request_stream(
168168
self._validate_model_request_parameters(model_request_parameters)
169169

170170
serialized_run_context = self.run_context_type.serialize_run_context(run_context)
171-
response = await workflow.execute_activity( # pyright: ignore[reportUnknownMemberType]
171+
response = await workflow.execute_activity(
172172
activity=self.request_stream_activity,
173173
args=[
174174
_RequestParams(

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"]
106106
# Retries
107107
retries = ["tenacity>=8.2.3"]
108108
# Temporal
109-
temporal = ["temporalio==1.18.2"]
109+
temporal = ["temporalio==1.19.0"]
110110
# DBOS
111111
dbos = ["dbos>=1.14.0"]
112112
# Prefect

0 commit comments

Comments
 (0)