|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import warnings |
4 | | -from collections.abc import AsyncIterator, Callable, Sequence |
5 | | -from contextlib import AbstractAsyncContextManager |
6 | 4 | from dataclasses import replace |
7 | 5 | from typing import Any |
8 | 6 |
|
9 | 7 | from pydantic.errors import PydanticUserError |
10 | | -from temporalio.client import ClientConfig, Plugin as ClientPlugin, WorkflowHistory |
11 | 8 | from temporalio.contrib.pydantic import PydanticPayloadConverter, pydantic_data_converter |
12 | 9 | from temporalio.converter import DataConverter, DefaultPayloadConverter |
13 | | -from temporalio.service import ConnectConfig, ServiceClient |
14 | | -from temporalio.worker import ( |
15 | | - Plugin as WorkerPlugin, |
16 | | - Replayer, |
17 | | - ReplayerConfig, |
18 | | - Worker, |
19 | | - WorkerConfig, |
20 | | - WorkflowReplayResult, |
21 | | -) |
| 10 | +from temporalio.plugin import SimplePlugin |
| 11 | +from temporalio.worker import WorkflowRunner |
22 | 12 | from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner |
23 | 13 |
|
24 | 14 | from ...exceptions import UserError |
|
48 | 38 | pass |
49 | 39 |
|
50 | 40 |
|
51 | | -class PydanticAIPlugin(ClientPlugin, WorkerPlugin): |
| 41 | +def _data_converter(converter: DataConverter | None) -> DataConverter: |
| 42 | + if converter and converter.payload_converter_class not in ( |
| 43 | + DefaultPayloadConverter, |
| 44 | + PydanticPayloadConverter, |
| 45 | + ): |
| 46 | + warnings.warn( # pragma: no cover |
| 47 | + 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.' |
| 48 | + ) |
| 49 | + |
| 50 | + return pydantic_data_converter |
| 51 | + |
| 52 | + |
| 53 | +def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: |
| 54 | + if not runner: |
| 55 | + raise ValueError('No WorkflowRunner provided to the Pydantic AI plugin.') # pragma: no cover |
| 56 | + |
| 57 | + if not isinstance(runner, SandboxedWorkflowRunner): |
| 58 | + return runner # pragma: no cover |
| 59 | + |
| 60 | + return replace( |
| 61 | + runner, |
| 62 | + restrictions=runner.restrictions.with_passthrough_modules( |
| 63 | + 'pydantic_ai', |
| 64 | + 'pydantic', |
| 65 | + 'pydantic_core', |
| 66 | + 'logfire', |
| 67 | + 'rich', |
| 68 | + 'httpx', |
| 69 | + 'anyio', |
| 70 | + 'httpcore', |
| 71 | + # Used by fastmcp via py-key-value-aio |
| 72 | + 'beartype', |
| 73 | + # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize |
| 74 | + 'attrs', |
| 75 | + # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize |
| 76 | + 'numpy', |
| 77 | + 'pandas', |
| 78 | + ), |
| 79 | + ) |
| 80 | + |
| 81 | + |
| 82 | +class PydanticAIPlugin(SimplePlugin): |
52 | 83 | """Temporal client and worker plugin for Pydantic AI.""" |
53 | 84 |
|
54 | | - def init_client_plugin(self, next: ClientPlugin) -> None: |
55 | | - self.next_client_plugin = next |
56 | | - |
57 | | - def init_worker_plugin(self, next: WorkerPlugin) -> None: |
58 | | - self.next_worker_plugin = next |
59 | | - |
60 | | - def configure_client(self, config: ClientConfig) -> ClientConfig: |
61 | | - config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) |
62 | | - return self.next_client_plugin.configure_client(config) |
63 | | - |
64 | | - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: |
65 | | - runner = config.get('workflow_runner') # pyright: ignore[reportUnknownMemberType] |
66 | | - if isinstance(runner, SandboxedWorkflowRunner): # pragma: no branch |
67 | | - config['workflow_runner'] = replace( |
68 | | - runner, |
69 | | - restrictions=runner.restrictions.with_passthrough_modules( |
70 | | - 'pydantic_ai', |
71 | | - 'pydantic', |
72 | | - 'pydantic_core', |
73 | | - 'logfire', |
74 | | - 'rich', |
75 | | - 'httpx', |
76 | | - 'anyio', |
77 | | - 'httpcore', |
78 | | - # Used by fastmcp via py-key-value-aio |
79 | | - 'beartype', |
80 | | - # Imported inside `logfire._internal.json_encoder` when running `logfire.info` inside an activity with attributes to serialize |
81 | | - 'attrs', |
82 | | - # Imported inside `logfire._internal.json_schema` when running `logfire.info` inside an activity with attributes to serialize |
83 | | - 'numpy', |
84 | | - 'pandas', |
85 | | - ), |
86 | | - ) |
87 | | - |
88 | | - config['workflow_failure_exception_types'] = [ |
89 | | - *config.get('workflow_failure_exception_types', []), # pyright: ignore[reportUnknownMemberType] |
90 | | - UserError, |
91 | | - PydanticUserError, |
92 | | - ] |
93 | | - |
94 | | - return self.next_worker_plugin.configure_worker(config) |
95 | | - |
96 | | - async def connect_service_client(self, config: ConnectConfig) -> ServiceClient: |
97 | | - return await self.next_client_plugin.connect_service_client(config) |
98 | | - |
99 | | - async def run_worker(self, worker: Worker) -> None: |
100 | | - await self.next_worker_plugin.run_worker(worker) |
101 | | - |
102 | | - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover |
103 | | - config['data_converter'] = self._get_new_data_converter(config.get('data_converter')) # pyright: ignore[reportUnknownMemberType] |
104 | | - return self.next_worker_plugin.configure_replayer(config) |
105 | | - |
106 | | - def run_replayer( |
107 | | - self, |
108 | | - replayer: Replayer, |
109 | | - histories: AsyncIterator[WorkflowHistory], |
110 | | - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover |
111 | | - return self.next_worker_plugin.run_replayer(replayer, histories) |
112 | | - |
113 | | - def _get_new_data_converter(self, converter: DataConverter | None) -> DataConverter: |
114 | | - if converter and converter.payload_converter_class not in ( |
115 | | - DefaultPayloadConverter, |
116 | | - PydanticPayloadConverter, |
117 | | - ): |
118 | | - warnings.warn( # pragma: no cover |
119 | | - 'A non-default Temporal data converter was used which has been replaced with the Pydantic data converter.' |
120 | | - ) |
121 | | - |
122 | | - return pydantic_data_converter |
123 | | - |
124 | | - |
125 | | -class AgentPlugin(WorkerPlugin): |
126 | | - """Temporal worker plugin for a specific Pydantic AI agent.""" |
127 | | - |
128 | | - def __init__(self, agent: TemporalAgent[Any, Any]): |
129 | | - self.agent = agent |
130 | | - |
131 | | - def init_worker_plugin(self, next: WorkerPlugin) -> None: |
132 | | - self.next_worker_plugin = next |
| 85 | + def __init__(self): |
| 86 | + super().__init__( # type: ignore[reportUnknownMemberType] |
| 87 | + name='PydanticAIPlugin', |
| 88 | + data_converter=_data_converter, |
| 89 | + workflow_runner=_workflow_runner, |
| 90 | + workflow_failure_exception_types=[UserError, PydanticUserError], |
| 91 | + ) |
133 | 92 |
|
134 | | - def configure_worker(self, config: WorkerConfig) -> WorkerConfig: |
135 | | - activities: Sequence[Callable[..., Any]] = config.get('activities', []) # pyright: ignore[reportUnknownMemberType] |
136 | | - # Activities are checked for name conflicts by Temporal. |
137 | | - config['activities'] = [*activities, *self.agent.temporal_activities] |
138 | | - return self.next_worker_plugin.configure_worker(config) |
139 | 93 |
|
140 | | - async def run_worker(self, worker: Worker) -> None: |
141 | | - await self.next_worker_plugin.run_worker(worker) |
142 | | - |
143 | | - def configure_replayer(self, config: ReplayerConfig) -> ReplayerConfig: # pragma: no cover |
144 | | - return self.next_worker_plugin.configure_replayer(config) |
| 94 | +class AgentPlugin(SimplePlugin): |
| 95 | + """Temporal worker plugin for a specific Pydantic AI agent.""" |
145 | 96 |
|
146 | | - def run_replayer( |
147 | | - self, |
148 | | - replayer: Replayer, |
149 | | - histories: AsyncIterator[WorkflowHistory], |
150 | | - ) -> AbstractAsyncContextManager[AsyncIterator[WorkflowReplayResult]]: # pragma: no cover |
151 | | - return self.next_worker_plugin.run_replayer(replayer, histories) |
| 97 | + def __init__(self, agent: TemporalAgent[Any, Any]): |
| 98 | + super().__init__( # type: ignore[reportUnknownMemberType] |
| 99 | + name='AgentPlugin', |
| 100 | + activities=agent.temporal_activities, |
| 101 | + ) |
0 commit comments