Skip to content

Commit 4001eb6

Browse files
authored
feat: added method for multiagent spans (#451)
1 parent 1ec793d commit 4001eb6

File tree

4 files changed

+186
-30
lines changed

4 files changed

+186
-30
lines changed

src/strands/multiagent/graph.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from dataclasses import dataclass, field
2222
from typing import Any, Callable, Tuple, cast
2323

24+
from opentelemetry import trace as trace_api
25+
2426
from ..agent import Agent, AgentResult
27+
from ..telemetry import get_tracer
2528
from ..types.content import ContentBlock
2629
from ..types.event_loop import Metrics, Usage
2730
from .base import MultiAgentBase, MultiAgentResult, NodeResult, Status
@@ -249,6 +252,7 @@ def __init__(self, nodes: dict[str, GraphNode], edges: set[GraphEdge], entry_poi
249252
self.edges = edges
250253
self.entry_points = entry_points
251254
self.state = GraphState()
255+
self.tracer = get_tracer()
252256

253257
def execute(self, task: str | list[ContentBlock]) -> GraphResult:
254258
"""Execute task synchronously."""
@@ -274,19 +278,20 @@ async def execute_async(self, task: str | list[ContentBlock]) -> GraphResult:
274278
)
275279

276280
start_time = time.time()
277-
try:
278-
await self._execute_graph()
279-
self.state.status = Status.COMPLETED
280-
logger.debug("status=<%s> | graph execution completed", self.state.status)
281-
282-
except Exception:
283-
logger.exception("graph execution failed")
284-
self.state.status = Status.FAILED
285-
raise
286-
finally:
287-
self.state.execution_time = round((time.time() - start_time) * 1000)
288-
289-
return self._build_result()
281+
span = self.tracer.start_multiagent_span(task, "graph")
282+
with trace_api.use_span(span, end_on_exit=True):
283+
try:
284+
await self._execute_graph()
285+
self.state.status = Status.COMPLETED
286+
logger.debug("status=<%s> | graph execution completed", self.state.status)
287+
288+
except Exception:
289+
logger.exception("graph execution failed")
290+
self.state.status = Status.FAILED
291+
raise
292+
finally:
293+
self.state.execution_time = round((time.time() - start_time) * 1000)
294+
return self._build_result()
290295

291296
async def _execute_graph(self) -> None:
292297
"""Unified execution flow with conditional routing."""

src/strands/telemetry/tracer.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from opentelemetry.trace import Span, StatusCode
1515

1616
from ..agent.agent_result import AgentResult
17-
from ..types.content import Message, Messages
17+
from ..types.content import ContentBlock, Message, Messages
1818
from ..types.streaming import StopReason, Usage
1919
from ..types.tools import ToolResult, ToolUse
2020
from ..types.traces import AttributeValue
@@ -86,8 +86,6 @@ def __init__(
8686
"""Initialize the tracer."""
8787
self.service_name = __name__
8888
self.tracer_provider: Optional[trace_api.TracerProvider] = None
89-
self.tracer: Optional[trace_api.Tracer] = None
90-
9189
self.tracer_provider = trace_api.get_tracer_provider()
9290
self.tracer = self.tracer_provider.get_tracer(self.service_name)
9391
ThreadingInstrumentor().instrument()
@@ -98,7 +96,7 @@ def _start_span(
9896
parent_span: Optional[Span] = None,
9997
attributes: Optional[Dict[str, AttributeValue]] = None,
10098
span_kind: trace_api.SpanKind = trace_api.SpanKind.INTERNAL,
101-
) -> Optional[Span]:
99+
) -> Span:
102100
"""Generic helper method to start a span with common attributes.
103101
104102
Args:
@@ -110,10 +108,13 @@ def _start_span(
110108
Returns:
111109
The created span.
112110
"""
113-
if self.tracer is None:
114-
return None
111+
if not parent_span:
112+
parent_span = trace_api.get_current_span()
113+
114+
context = None
115+
if parent_span and parent_span.is_recording() and parent_span != trace_api.INVALID_SPAN:
116+
context = trace_api.set_span_in_context(parent_span)
115117

116-
context = trace_api.set_span_in_context(parent_span) if parent_span else None
117118
span = self.tracer.start_span(name=span_name, context=context, kind=span_kind)
118119

119120
# Set start time as a common attribute
@@ -235,7 +236,7 @@ def start_model_invoke_span(
235236
# Add additional kwargs as attributes
236237
attributes.update({k: v for k, v in kwargs.items() if isinstance(v, (str, int, float, bool))})
237238

238-
span = self._start_span("chat", parent_span, attributes, span_kind=trace_api.SpanKind.CLIENT)
239+
span = self._start_span("chat", parent_span, attributes=attributes, span_kind=trace_api.SpanKind.CLIENT)
239240
for message in messages:
240241
self._add_event(
241242
span,
@@ -293,8 +294,8 @@ def start_tool_call_span(self, tool: ToolUse, parent_span: Optional[Span] = None
293294
# Add additional kwargs as attributes
294295
attributes.update(kwargs)
295296

296-
span_name = f"Tool: {tool['name']}"
297-
span = self._start_span(span_name, parent_span, attributes, span_kind=trace_api.SpanKind.INTERNAL)
297+
span_name = f"execute_tool {tool['name']}"
298+
span = self._start_span(span_name, parent_span, attributes=attributes, span_kind=trace_api.SpanKind.INTERNAL)
298299

299300
self._add_event(
300301
span,
@@ -497,6 +498,41 @@ def end_agent_span(
497498

498499
self._end_span(span, attributes, error)
499500

501+
def start_multiagent_span(
502+
self,
503+
task: str | list[ContentBlock],
504+
instance: str,
505+
) -> Span:
506+
"""Start a new span for a multi-agent (e.g. graph or swarm) invocation."""
507+
attributes: Dict[str, AttributeValue] = {
508+
"gen_ai.system": "strands-agents",
509+
"gen_ai.agent.name": instance,
510+
"gen_ai.operation.name": f"invoke_{instance}",
511+
}
512+
513+
span = self._start_span(f"invoke_{instance}", attributes=attributes, span_kind=trace_api.SpanKind.CLIENT)
514+
content = serialize(task) if isinstance(task, list) else task
515+
self._add_event(
516+
span,
517+
"gen_ai.user.message",
518+
event_attributes={"content": content},
519+
)
520+
521+
return span
522+
523+
def end_swarm_span(
524+
self,
525+
span: Span,
526+
result: Optional[str] = None,
527+
) -> None:
528+
"""Record the swarm result, if any, as a gen_ai.choice event on the span."""
529+
if result:
530+
self._add_event(
531+
span,
532+
"gen_ai.choice",
533+
event_attributes={"message": result},
534+
)
535+
500536

501537
# Singleton instance for global access
502538
_tracer_instance = None

tests/strands/multiagent/test_graph.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import AsyncMock, MagicMock, Mock
1+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
22

33
import pytest
44

@@ -101,6 +101,22 @@ def string_content_agent():
101101
return agent
102102

103103

104+
@pytest.fixture
105+
def mock_strands_tracer():
106+
with patch("strands.multiagent.graph.get_tracer") as mock_get_tracer:
107+
mock_tracer_instance = MagicMock()
108+
mock_span = MagicMock()
109+
mock_tracer_instance.start_multiagent_span.return_value = mock_span
110+
mock_get_tracer.return_value = mock_tracer_instance
111+
yield mock_tracer_instance
112+
113+
114+
@pytest.fixture
115+
def mock_use_span():
116+
with patch("strands.multiagent.graph.trace_api.use_span") as mock_use_span:
117+
yield mock_use_span
118+
119+
104120
@pytest.fixture
105121
def mock_graph(mock_agents, string_content_agent):
106122
"""Create a graph for testing various scenarios."""
@@ -138,8 +154,9 @@ def always_false_condition(state: GraphState) -> bool:
138154

139155

140156
@pytest.mark.asyncio
141-
async def test_graph_execution(mock_graph, mock_agents, string_content_agent):
157+
async def test_graph_execution(mock_strands_tracer, mock_use_span, mock_graph, mock_agents, string_content_agent):
142158
"""Test comprehensive graph execution with diverse nodes and conditional edges."""
159+
143160
# Test graph structure
144161
assert len(mock_graph.nodes) == 8
145162
assert len(mock_graph.edges) == 8
@@ -214,9 +231,12 @@ async def test_graph_execution(mock_graph, mock_agents, string_content_agent):
214231
assert len(result.entry_points) == 1
215232
assert result.entry_points[0].node_id == "start_agent"
216233

234+
mock_strands_tracer.start_multiagent_span.assert_called()
235+
mock_use_span.assert_called_once()
236+
217237

218238
@pytest.mark.asyncio
219-
async def test_graph_unsupported_node_type():
239+
async def test_graph_unsupported_node_type(mock_strands_tracer, mock_use_span):
220240
"""Test unsupported executor type error handling."""
221241

222242
class UnsupportedExecutor:
@@ -229,9 +249,12 @@ class UnsupportedExecutor:
229249
with pytest.raises(ValueError, match="Node 'unsupported_node' of type.*is not supported"):
230250
await graph.execute_async("test task")
231251

252+
mock_strands_tracer.start_multiagent_span.assert_called()
253+
mock_use_span.assert_called_once()
254+
232255

233256
@pytest.mark.asyncio
234-
async def test_graph_execution_with_failures():
257+
async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span):
235258
"""Test graph execution error handling and failure propagation."""
236259
failing_agent = Mock(spec=Agent)
237260
failing_agent.name = "failing_agent"
@@ -261,10 +284,12 @@ async def mock_stream_failure(*args, **kwargs):
261284
assert graph.state.status == Status.FAILED
262285
assert any(node.node_id == "fail_node" for node in graph.state.failed_nodes)
263286
assert len(graph.state.completed_nodes) == 0
287+
mock_strands_tracer.start_multiagent_span.assert_called()
288+
mock_use_span.assert_called_once()
264289

265290

266291
@pytest.mark.asyncio
267-
async def test_graph_edge_cases():
292+
async def test_graph_edge_cases(mock_strands_tracer, mock_use_span):
268293
"""Test specific edge cases for coverage."""
269294
# Test entry node execution without dependencies
270295
entry_agent = create_mock_agent("entry_agent", "Entry response")
@@ -278,6 +303,8 @@ async def test_graph_edge_cases():
278303
# Verify entry node was called with original task
279304
entry_agent.stream_async.assert_called_once_with([{"text": "Original task"}])
280305
assert result.status == Status.COMPLETED
306+
mock_strands_tracer.start_multiagent_span.assert_called()
307+
mock_use_span.assert_called_once()
281308

282309

283310
def test_graph_builder_validation():
@@ -415,7 +442,7 @@ def test_condition(state):
415442
assert len(node.dependencies) == 0
416443

417444

418-
def test_graph_synchronous_execution(mock_agents):
445+
def test_graph_synchronous_execution(mock_strands_tracer, mock_use_span, mock_agents):
419446
"""Test synchronous graph execution using execute method."""
420447
builder = GraphBuilder()
421448
builder.add_node(mock_agents["start_agent"], "start_agent")
@@ -444,3 +471,6 @@ def test_graph_synchronous_execution(mock_agents):
444471
# Verify return type is GraphResult
445472
assert isinstance(result, GraphResult)
446473
assert isinstance(result, MultiAgentResult)
474+
475+
mock_strands_tracer.start_multiagent_span.assert_called()
476+
mock_use_span.assert_called_once()

tests/strands/telemetry/test_tracer.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111

1212
from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize
13+
from strands.types.content import ContentBlock
1314
from strands.types.streaming import StopReason, Usage
1415

1516

@@ -198,7 +199,91 @@ def test_start_tool_call_span(mock_tracer):
198199
span = tracer.start_tool_call_span(tool)
199200

200201
mock_tracer.start_span.assert_called_once()
201-
assert mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool"
202+
assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool"
203+
mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool")
204+
mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents")
205+
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool")
206+
mock_span.set_attribute.assert_any_call("gen_ai.tool.call.id", "123")
207+
mock_span.add_event.assert_any_call(
208+
"gen_ai.tool.message", attributes={"role": "tool", "content": json.dumps({"param": "value"}), "id": "123"}
209+
)
210+
assert span is not None
211+
212+
213+
def test_start_swarm_call_span_with_string_task(mock_tracer):
214+
"""Test starting a swarm call span with task as string."""
215+
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):
216+
tracer = Tracer()
217+
tracer.tracer = mock_tracer
218+
219+
mock_span = mock.MagicMock()
220+
mock_tracer.start_span.return_value = mock_span
221+
222+
task = "Design foo bar"
223+
224+
span = tracer.start_multiagent_span(task, "swarm")
225+
226+
mock_tracer.start_span.assert_called_once()
227+
assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm"
228+
mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents")
229+
mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm")
230+
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm")
231+
mock_span.add_event.assert_any_call("gen_ai.user.message", attributes={"content": "Design foo bar"})
232+
assert span is not None
233+
234+
235+
def test_start_swarm_span_with_contentblock_task(mock_tracer):
236+
"""Test starting a swarm call span with task as a list of ContentBlock."""
237+
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):
238+
tracer = Tracer()
239+
tracer.tracer = mock_tracer
240+
241+
mock_span = mock.MagicMock()
242+
mock_tracer.start_span.return_value = mock_span
243+
244+
task = [ContentBlock(text="Original Task: foo bar")]
245+
246+
span = tracer.start_multiagent_span(task, "swarm")
247+
248+
mock_tracer.start_span.assert_called_once()
249+
assert mock_tracer.start_span.call_args[1]["name"] == "invoke_swarm"
250+
mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents")
251+
mock_span.set_attribute.assert_any_call("gen_ai.agent.name", "swarm")
252+
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "invoke_swarm")
253+
mock_span.add_event.assert_any_call(
254+
"gen_ai.user.message", attributes={"content": '[{"text": "Original Task: foo bar"}]'}
255+
)
256+
assert span is not None
257+
258+
259+
def test_end_swarm_span(mock_span):
260+
"""Test ending a swarm span."""
261+
tracer = Tracer()
262+
swarm_final_reuslt = "foo bar bar"
263+
264+
tracer.end_swarm_span(mock_span, swarm_final_reuslt)
265+
266+
mock_span.add_event.assert_called_with(
267+
"gen_ai.choice",
268+
attributes={"message": "foo bar bar"},
269+
)
270+
271+
272+
def test_start_graph_call_span(mock_tracer):
273+
"""Test starting a graph call span."""
274+
with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer):
275+
tracer = Tracer()
276+
tracer.tracer = mock_tracer
277+
278+
mock_span = mock.MagicMock()
279+
mock_tracer.start_span.return_value = mock_span
280+
281+
tool = {"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}}
282+
283+
span = tracer.start_tool_call_span(tool)
284+
285+
mock_tracer.start_span.assert_called_once()
286+
assert mock_tracer.start_span.call_args[1]["name"] == "execute_tool test-tool"
202287
mock_span.set_attribute.assert_any_call("gen_ai.tool.name", "test-tool")
203288
mock_span.set_attribute.assert_any_call("gen_ai.system", "strands-agents")
204289
mock_span.set_attribute.assert_any_call("gen_ai.operation.name", "execute_tool")

0 commit comments

Comments
 (0)