diff --git a/docs/logfire.md b/docs/logfire.md index 94fe349340..ea15b058c0 100644 --- a/docs/logfire.md +++ b/docs/logfire.md @@ -356,3 +356,20 @@ Agent.instrument_all(instrumentation_settings) ``` This setting is particularly useful in production environments where compliance requirements or data sensitivity concerns make it necessary to limit what content is sent to your observability platform. + +### Adding Custom Metadata + +Use the agent's `metadata` parameter to attach additional data to the agent's span. +Metadata can be provided as a string, a dictionary, or a callable that reads the [`RunContext`][pydantic_ai.tools.RunContext] to compute values on each run. + +```python {hl_lines="4-5"} +from pydantic_ai import Agent + +agent = Agent( + 'openai:gpt-5', + instrument=True, + metadata=lambda ctx: {'deployment': 'staging', 'tenant': ctx.deps.tenant}, +) +``` + +When instrumentation is enabled, the resolved metadata is recorded on the agent span under the `logfire.agent.metadata` attribute. diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 4cd353b44a..b44785e372 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -32,6 +32,7 @@ HistoryProcessor, ModelRequestNode, UserPromptNode, + build_run_context, capture_run_messages, ) from .._output import OutputToolset @@ -89,6 +90,8 @@ S = TypeVar('S') NoneType = type(None) +AgentMetadataValue = str | dict[str, str] | Callable[[RunContext[AgentDepsT]], str | dict[str, str]] + @dataclasses.dataclass(init=False) class Agent(AbstractAgent[AgentDepsT, OutputDataT]): @@ -130,6 +133,7 @@ class Agent(AbstractAgent[AgentDepsT, OutputDataT]): """Options to automatically instrument with OpenTelemetry.""" _instrument_default: ClassVar[InstrumentationSettings | bool] = False + _metadata: AgentMetadataValue[AgentDepsT] | None = dataclasses.field(repr=False) _deps_type: type[AgentDepsT] = dataclasses.field(repr=False) _output_schema: _output.OutputSchema[OutputDataT] = dataclasses.field(repr=False) @@ -175,6 +179,7 @@ def __init__( defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, + metadata: AgentMetadataValue[AgentDepsT] | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> None: ... @@ -201,6 +206,7 @@ def __init__( defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, + metadata: AgentMetadataValue[AgentDepsT] | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, ) -> None: ... @@ -225,6 +231,7 @@ def __init__( defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, + metadata: AgentMetadataValue[AgentDepsT] | None = None, history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, **_deprecated_kwargs: Any, @@ -276,6 +283,10 @@ def __init__( [`Agent.instrument_all()`][pydantic_ai.Agent.instrument_all] will be used, which defaults to False. See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info. + metadata: Optional metadata to attach to telemetry for this agent. + Provide a string literal, a dict of string keys and values, or a callable returning one of those values + computed from the [`RunContext`][pydantic_ai.tools.RunContext] on each run. + Metadata is only recorded when instrumentation is enabled. history_processors: Optional list of callables to process the message history before sending it to the model. Each processor takes a list of messages and returns a modified list of messages. Processors can be sync or async and are applied in sequence. @@ -292,6 +303,7 @@ def __init__( self._output_type = output_type self.instrument = instrument + self._metadata = metadata self._deps_type = deps_type if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): @@ -349,6 +361,9 @@ def __init__( self._override_instructions: ContextVar[ _utils.Option[list[str | _system_prompt.SystemPromptFunc[AgentDepsT]]] ] = ContextVar('_override_instructions', default=None) + self._override_metadata: ContextVar[_utils.Option[AgentMetadataValue[AgentDepsT]]] = ContextVar( + '_override_metadata', default=None + ) self._enter_lock = Lock() self._entered_count = 0 @@ -645,6 +660,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: }, ) + run_metadata: str | dict[str, str] | None = None try: async with graph.iter( inputs=user_prompt_node, @@ -656,8 +672,10 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: async with toolset: agent_run = AgentRun(graph_run) yield agent_run - if (final_result := agent_run.result) is not None and run_span.is_recording(): - if instrumentation_settings and instrumentation_settings.include_content: + final_result = agent_run.result + if instrumentation_settings and run_span.is_recording(): + run_metadata = self._compute_agent_metadata(build_run_context(agent_run.ctx)) + if instrumentation_settings.include_content and final_result is not None: run_span.set_attribute( 'final_result', ( @@ -671,18 +689,32 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: if instrumentation_settings and run_span.is_recording(): run_span.set_attributes( self._run_span_end_attributes( - instrumentation_settings, usage, state.message_history, graph_deps.new_message_index + instrumentation_settings, + usage, + state.message_history, + graph_deps.new_message_index, + run_metadata, ) ) finally: run_span.end() + def _compute_agent_metadata(self, ctx: RunContext[AgentDepsT]) -> str | dict[str, str] | None: + metadata_override = self._override_metadata.get() + metadata_config = metadata_override.value if metadata_override is not None else self._metadata + if metadata_config is None: + return None + + metadata = metadata_config(ctx) if callable(metadata_config) else metadata_config + return metadata + def _run_span_end_attributes( self, settings: InstrumentationSettings, usage: _usage.RunUsage, message_history: list[_messages.ModelMessage], new_message_index: int, + metadata: str | dict[str, str] | None = None, ): if settings.version == 1: attrs = { @@ -716,6 +748,12 @@ def _run_span_end_attributes( ): attrs['pydantic_ai.variable_instructions'] = True + if metadata is not None: + if isinstance(metadata, dict): + attrs['logfire.agent.metadata'] = json.dumps(metadata) + else: + attrs['logfire.agent.metadata'] = metadata + return { **usage.opentelemetry_attributes(), **attrs, @@ -740,6 +778,7 @@ def override( toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, instructions: Instructions[AgentDepsT] | _utils.Unset = _utils.UNSET, + metadata: AgentMetadataValue[AgentDepsT] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: """Context manager to temporarily override agent name, dependencies, model, toolsets, tools, or instructions. @@ -753,6 +792,7 @@ def override( toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. tools: The tools to use instead of the tools registered with the agent. instructions: The instructions to use instead of the instructions registered with the agent. + metadata: The metadata to use instead of the metadata passed to the agent constructor. """ if _utils.is_set(name): name_token = self._override_name.set(_utils.Some(name)) @@ -785,6 +825,11 @@ def override( else: instructions_token = None + if _utils.is_set(metadata): + metadata_token = self._override_metadata.set(_utils.Some(metadata)) + else: + metadata_token = None + try: yield finally: @@ -800,6 +845,8 @@ def override( self._override_tools.reset(tools_token) if instructions_token is not None: self._override_instructions.reset(instructions_token) + if metadata_token is not None: + self._override_metadata.reset(metadata_token) @overload def instructions( diff --git a/tests/test_logfire.py b/tests/test_logfire.py index dadb930dd0..514d5c4aec 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -120,6 +120,7 @@ async def my_ret(x: int) -> str: model=TestModel(), toolsets=[toolset], instrument=instrument, + metadata={'env': 'test'}, ) result = my_agent.run_sync('Hello') @@ -314,12 +315,14 @@ async def my_ret(x: int) -> str: ] ) ), + 'logfire.agent.metadata': '{"env": "test"}', 'logfire.json_schema': IsJson( snapshot( { 'type': 'object', 'properties': { 'pydantic_ai.all_messages': {'type': 'array'}, + 'logfire.agent.metadata': {'type': 'array'}, 'final_result': {'type': 'object'}, }, } @@ -379,12 +382,14 @@ async def my_ret(x: int) -> str: ) ), 'final_result': '{"my_ret":"1"}', + 'logfire.agent.metadata': '{"env": "test"}', 'logfire.json_schema': IsJson( snapshot( { 'type': 'object', 'properties': { 'all_messages_events': {'type': 'array'}, + 'logfire.agent.metadata': {'type': 'array'}, 'final_result': {'type': 'object'}, }, } @@ -569,6 +574,46 @@ async def my_ret(x: int) -> str: ) +def _test_logfire_metadata_values_callable_dict(ctx: RunContext[Any]) -> dict[str, str]: + return {'model_name': ctx.model.model_name} + + +def _test_logfire_metadata_values_callable_string(_ctx: RunContext[Any]) -> str: + return 'callable-str' + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize( + ('metadata', 'expected'), + [ + pytest.param({'env': 'test'}, '{"env": "test"}', id='dict'), + pytest.param('staging', 'staging', id='literal-string'), + pytest.param(_test_logfire_metadata_values_callable_dict, '{"model_name": "test"}', id='callable-dict'), + pytest.param(_test_logfire_metadata_values_callable_string, 'callable-str', id='callable-string'), + ], +) +def test_logfire_metadata_values( + get_logfire_summary: Callable[[], LogfireSummary], + metadata: str | dict[str, str] | Callable[[RunContext[Any]], str | dict[str, str]], + expected: str | dict[str, str], +) -> None: + agent = Agent(model=TestModel(), instrument=InstrumentationSettings(version=2), metadata=metadata) + agent.run_sync('Hello') + + summary = get_logfire_summary() + assert summary.attributes[0]['logfire.agent.metadata'] == expected + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +def test_logfire_metadata_override(get_logfire_summary: Callable[[], LogfireSummary]) -> None: + agent = Agent(model=TestModel(), instrument=InstrumentationSettings(version=2), metadata='base') + with agent.override(metadata={'env': 'override'}): + agent.run_sync('Hello') + + summary = get_logfire_summary() + assert summary.attributes[0]['logfire.agent.metadata'] == '{"env": "override"}' + + @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') @pytest.mark.parametrize( 'instrument',