Skip to content

Commit e6f7363

Browse files
RKestcopybara-github
authored andcommitted
fix: set execute_tool {tool.name} span attributes even when exception occurs during tool's execution
PiperOrigin-RevId: 824165197
1 parent 5d9a7e7 commit e6f7363

File tree

3 files changed

+103
-39
lines changed

3 files changed

+103
-39
lines changed

src/google/adk/flows/llm_flows/functions.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,9 @@ async def _execute_single_function_call_async(
305305
else:
306306
raise tool_error
307307

308-
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
308+
async def _run_with_trace():
309+
nonlocal function_args
310+
309311
# Step 1: Check if plugin before_tool_callback overrides the function
310312
# response.
311313
function_response = (
@@ -391,13 +393,23 @@ async def _execute_single_function_call_async(
391393
function_response_event = __build_response_event(
392394
tool, function_response, tool_context, invocation_context
393395
)
394-
trace_tool_call(
395-
tool=tool,
396-
args=function_args,
397-
function_response_event=function_response_event,
398-
)
399396
return function_response_event
400397

398+
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
399+
try:
400+
function_response_event = await _run_with_trace()
401+
trace_tool_call(
402+
tool=tool,
403+
args=function_args,
404+
function_response_event=function_response_event,
405+
)
406+
return function_response_event
407+
except:
408+
trace_tool_call(
409+
tool=tool, args=function_args, function_response_event=None
410+
)
411+
raise
412+
401413

402414
async def handle_function_calls_live(
403415
invocation_context: InvocationContext,
@@ -467,13 +479,17 @@ async def _execute_single_function_call_live(
467479
tool, tool_context = _get_tool_and_context(
468480
invocation_context, function_call, tools_dict
469481
)
470-
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
482+
483+
function_args = (
484+
copy.deepcopy(function_call.args) if function_call.args else {}
485+
)
486+
487+
async def _run_with_trace():
488+
nonlocal function_args
489+
471490
# Do not use "args" as the variable name, because it is a reserved keyword
472491
# in python debugger.
473492
# Make a deep copy to avoid being modified.
474-
function_args = (
475-
copy.deepcopy(function_call.args) if function_call.args else {}
476-
)
477493
function_response = None
478494

479495
# Handle before_tool_callbacks - iterate through the canonical callback
@@ -527,13 +543,23 @@ async def _execute_single_function_call_live(
527543
function_response_event = __build_response_event(
528544
tool, function_response, tool_context, invocation_context
529545
)
530-
trace_tool_call(
531-
tool=tool,
532-
args=function_args,
533-
function_response_event=function_response_event,
534-
)
535546
return function_response_event
536547

548+
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
549+
try:
550+
function_response_event = await _run_with_trace()
551+
trace_tool_call(
552+
tool=tool,
553+
args=function_args,
554+
function_response_event=function_response_event,
555+
)
556+
return function_response_event
557+
except:
558+
trace_tool_call(
559+
tool=tool, args=function_args, function_response_event=None
560+
)
561+
raise
562+
537563

538564
async def _process_function_live_helper(
539565
tool,

src/google/adk/telemetry/tracing.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import json
2727
import os
2828
from typing import Any
29+
from typing import Optional
2930
from typing import TYPE_CHECKING
3031

3132
from google.genai import types
@@ -118,7 +119,7 @@ def trace_agent_invocation(
118119
def trace_tool_call(
119120
tool: BaseTool,
120121
args: dict[str, Any],
121-
function_response_event: Event,
122+
function_response_event: Optional[Event],
122123
):
123124
"""Traces tool call.
124125
@@ -154,7 +155,8 @@ def trace_tool_call(
154155
tool_call_id = '<not specified>'
155156
tool_response = '<not specified>'
156157
if (
157-
function_response_event.content is not None
158+
function_response_event is not None
159+
and function_response_event.content is not None
158160
and function_response_event.content.parts
159161
):
160162
response_parts = function_response_event.content.parts
@@ -169,7 +171,8 @@ def trace_tool_call(
169171

170172
if not isinstance(tool_response, dict):
171173
tool_response = {'result': tool_response}
172-
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
174+
if function_response_event is not None:
175+
span.set_attribute('gcp.vertex.agent.event_id', function_response_event.id)
173176
if _should_add_request_response_to_spans():
174177
span.set_attribute(
175178
'gcp.vertex.agent.tool_response',

tests/unittests/telemetry/test_functional.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
import gc
1617
import sys
17-
from unittest import mock
1818

1919
from google.adk.agents import base_agent
2020
from google.adk.agents.llm_agent import Agent
2121
from google.adk.models.base_llm import BaseLlm
22+
from google.adk.models.llm_response import LlmResponse
2223
from google.adk.telemetry import tracing
2324
from google.adk.tools import FunctionTool
2425
from google.adk.utils.context_utils import Aclosing
26+
from google.genai.types import Content
2527
from google.genai.types import Part
26-
from opentelemetry.version import __version__
28+
from opentelemetry.sdk.trace import TracerProvider
29+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
30+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
2731
import pytest
2832

2933
from ..testing_utils import MockModel
@@ -63,27 +67,27 @@ async def test_runner(test_agent: Agent) -> TestInMemoryRunner:
6367

6468

6569
@pytest.fixture
66-
def mock_start_as_current_span(monkeypatch: pytest.MonkeyPatch) -> mock.Mock:
67-
mock_context_manager = mock.MagicMock()
68-
mock_context_manager.__enter__.return_value = mock.Mock()
69-
mock_start_as_current_span = mock.Mock()
70-
mock_start_as_current_span.return_value = mock_context_manager
70+
def span_exporter(monkeypatch: pytest.MonkeyPatch) -> InMemorySpanExporter:
71+
tracer_provider = TracerProvider()
72+
span_exporter = InMemorySpanExporter()
73+
tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))
74+
real_tracer = tracer_provider.get_tracer(__name__)
7175

7276
def do_replace(tracer):
7377
monkeypatch.setattr(
74-
tracer, 'start_as_current_span', mock_start_as_current_span
78+
tracer, 'start_as_current_span', real_tracer.start_as_current_span
7579
)
7680

7781
do_replace(tracing.tracer)
7882
do_replace(base_agent.tracer)
7983

80-
return mock_start_as_current_span
84+
return span_exporter
8185

8286

8387
@pytest.mark.asyncio
8488
async def test_tracer_start_as_current_span(
8589
test_runner: TestInMemoryRunner,
86-
mock_start_as_current_span: mock.Mock,
90+
span_exporter: InMemorySpanExporter,
8791
):
8892
"""Test creation of multiple spans in an E2E runner invocation.
8993
@@ -112,18 +116,49 @@ def wrapped_firstiter(coro):
112116
pass
113117

114118
# Assert
115-
expected_start_as_current_span_calls = [
116-
mock.call('invocation'),
117-
mock.call('execute_tool some_tool'),
118-
mock.call('invoke_agent some_root_agent'),
119-
mock.call('call_llm'),
120-
mock.call('call_llm'),
119+
spans = span_exporter.get_finished_spans()
120+
assert list(sorted(span.name for span in spans)) == [
121+
'call_llm',
122+
'call_llm',
123+
'execute_tool some_tool',
124+
'invocation',
125+
'invoke_agent some_root_agent',
121126
]
122127

123-
mock_start_as_current_span.assert_has_calls(
124-
expected_start_as_current_span_calls,
125-
any_order=True,
128+
129+
@pytest.mark.asyncio
130+
async def test_exception_preserves_attributes(
131+
test_model: BaseLlm, span_exporter: InMemorySpanExporter
132+
):
133+
"""Test when an exception occurs during tool execution, span attributes are still present on spans where they are expected."""
134+
135+
# Arrange
136+
async def some_tool():
137+
raise ValueError('This tool always fails')
138+
139+
test_agent = Agent(
140+
name='some_root_agent',
141+
model=test_model,
142+
tools=[
143+
FunctionTool(some_tool),
144+
],
126145
)
127-
assert mock_start_as_current_span.call_count == len(
128-
expected_start_as_current_span_calls
146+
147+
test_runner = TestInMemoryRunner(test_agent)
148+
149+
# Act
150+
with pytest.raises(ValueError, match='This tool always fails'):
151+
async with Aclosing(
152+
test_runner.run_async_with_new_session_agen('')
153+
) as agen:
154+
async for _ in agen:
155+
pass
156+
157+
# Assert
158+
spans = span_exporter.get_finished_spans()
159+
assert len(spans) > 1
160+
assert all(
161+
span.attributes is not None and len(span.attributes) > 0
162+
for span in spans
163+
if span.name != 'invocation' # not expected to have attributes
129164
)

0 commit comments

Comments
 (0)