Commit 428750b

event loop - handle model execution (#958)
1 parent 2493545 commit 428750b

File tree: 1 file changed (+135, -99 lines)


src/strands/event_loop/event_loop.py

Lines changed: 135 additions & 99 deletions
@@ -17,7 +17,7 @@
 
 from ..hooks import AfterModelCallEvent, BeforeModelCallEvent, MessageAddedEvent
 from ..telemetry.metrics import Trace
-from ..telemetry.tracer import get_tracer
+from ..telemetry.tracer import Tracer, get_tracer
 from ..tools._validator import validate_and_prepare_tools
 from ..types._events import (
     EventLoopStopEvent,
@@ -37,7 +37,7 @@
     MaxTokensReachedException,
     ModelThrottledException,
 )
-from ..types.streaming import Metrics, StopReason
+from ..types.streaming import StopReason
 from ..types.tools import ToolResult, ToolUse
 from ._recover_message_on_max_tokens_reached import recover_message_on_max_tokens_reached
 from .streaming import stream_messages
@@ -106,16 +106,142 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
     )
     invocation_state["event_loop_cycle_span"] = cycle_span
 
+    model_events = _handle_model_execution(agent, cycle_span, cycle_trace, invocation_state, tracer)
+    async for model_event in model_events:
+        if not isinstance(model_event, ModelStopReason):
+            yield model_event
+
+    stop_reason, message, *_ = model_event["stop"]
+    yield ModelMessageEvent(message=message)
+
+    try:
+        if stop_reason == "max_tokens":
+            """
+            Handle max_tokens limit reached by the model.
+
+            When the model reaches its maximum token limit, this represents a potentially unrecoverable
+            state where the model's response was truncated. By default, Strands fails hard with an
+            MaxTokensReachedException to maintain consistency with other failure types.
+            """
+            raise MaxTokensReachedException(
+                message=(
+                    "Agent has reached an unrecoverable state due to max_tokens limit. "
+                    "For more information see: "
+                    "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
+                )
+            )
+
+        # If the model is requesting to use tools
+        if stop_reason == "tool_use":
+            # Handle tool execution
+            tool_events = _handle_tool_execution(
+                stop_reason,
+                message,
+                agent=agent,
+                cycle_trace=cycle_trace,
+                cycle_span=cycle_span,
+                cycle_start_time=cycle_start_time,
+                invocation_state=invocation_state,
+            )
+            async for tool_event in tool_events:
+                yield tool_event
+
+            return
+
+        # End the cycle and return results
+        agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes)
+        if cycle_span:
+            tracer.end_event_loop_cycle_span(
+                span=cycle_span,
+                message=message,
+            )
+    except EventLoopException as e:
+        if cycle_span:
+            tracer.end_span_with_error(cycle_span, str(e), e)
+
+        # Don't yield or log the exception - we already did it when we
+        # raised the exception and we don't need that duplication.
+        raise
+    except (ContextWindowOverflowException, MaxTokensReachedException) as e:
+        # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
+        if cycle_span:
+            tracer.end_span_with_error(cycle_span, str(e), e)
+        raise e
+    except Exception as e:
+        if cycle_span:
+            tracer.end_span_with_error(cycle_span, str(e), e)
+
+        # Handle any other exceptions
+        yield ForceStopEvent(reason=e)
+        logger.exception("cycle failed")
+        raise EventLoopException(e, invocation_state["request_state"]) from e
+
+    yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
+
+
+async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
+    """Make a recursive call to event_loop_cycle with the current state.
+
+    This function is used when the event loop needs to continue processing after tool execution.
+
+    Args:
+        agent: Agent for which the recursive call is being made.
+        invocation_state: Arguments to pass through event_loop_cycle
+
+
+    Yields:
+        Results from event_loop_cycle where the last result contains:
+
+        - StopReason: Reason the model stopped generating
+        - Message: The generated message from the model
+        - EventLoopMetrics: Updated metrics for the event loop
+        - Any: Updated request state
+    """
+    cycle_trace = invocation_state["event_loop_cycle_trace"]
+
+    # Recursive call trace
+    recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id)
+    cycle_trace.add_child(recursive_trace)
+
+    yield StartEvent()
+
+    events = event_loop_cycle(agent=agent, invocation_state=invocation_state)
+    async for event in events:
+        yield event
+
+    recursive_trace.end()
+
+
+async def _handle_model_execution(
+    agent: "Agent",
+    cycle_span: Any,
+    cycle_trace: Trace,
+    invocation_state: dict[str, Any],
+    tracer: Tracer,
+) -> AsyncGenerator[TypedEvent, None]:
+    """Handle model execution with retry logic for throttling exceptions.
+
+    Executes the model inference with automatic retry handling for throttling exceptions.
+    Manages tracing, hooks, and metrics collection throughout the process.
+
+    Args:
+        agent: The agent executing the model.
+        cycle_span: Span object for tracing the cycle.
+        cycle_trace: Trace object for the current event loop cycle.
+        invocation_state: State maintained across cycles.
+        tracer: Tracer instance for span management.
+
+    Yields:
+        Model stream events and throttle events during retries.
+
+    Raises:
+        ModelThrottledException: If max retry attempts are exceeded.
+        Exception: Any other model execution errors.
+    """
     # Create a trace for the stream_messages call
     stream_trace = Trace("stream_messages", parent_id=cycle_trace.id)
     cycle_trace.add_child(stream_trace)
 
-    # Process messages with exponential backoff for throttling
-    message: Message
-    stop_reason: StopReason
-    usage: Any
-    metrics: Metrics
-
     # Retry loop for handling throttling exceptions
     current_delay = INITIAL_DELAY
     for attempt in range(MAX_ATTEMPTS):
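Note: the retry loop that now lives in `_handle_model_execution` follows a conventional exponential-backoff pattern: wait, double the delay up to a cap, and re-raise once the attempts are exhausted. Below is a minimal, self-contained sketch of that pattern only; the constant values, `ThrottledError`, and `call_with_backoff` are illustrative stand-ins, not definitions from this module.

```python
import asyncio

# Illustrative values; the real INITIAL_DELAY / MAX_DELAY / MAX_ATTEMPTS are defined in the module itself.
INITIAL_DELAY = 4
MAX_DELAY = 240
MAX_ATTEMPTS = 6


class ThrottledError(Exception):
    """Stand-in for a provider throttling error such as ModelThrottledException."""


async def call_with_backoff(invoke_model):
    """Retry `invoke_model` with exponential backoff when it raises ThrottledError."""
    current_delay = INITIAL_DELAY
    for attempt in range(MAX_ATTEMPTS):
        try:
            return await invoke_model()
        except ThrottledError:
            if attempt + 1 == MAX_ATTEMPTS:
                raise  # retries exhausted; surface the throttling error to the caller
            await asyncio.sleep(current_delay)
            current_delay = min(current_delay * 2, MAX_DELAY)
```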
@@ -136,8 +262,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
 
         try:
             async for event in stream_messages(agent.model, agent.system_prompt, agent.messages, tool_specs):
-                if not isinstance(event, ModelStopReason):
-                    yield event
+                yield event
 
             stop_reason, message, usage, metrics = event["stop"]
             invocation_state.setdefault("request_state", {})
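Note: this hunk moves stop-event filtering out of the model stream. The inner generator now yields every event, including the terminal stop event, and the caller in `event_loop_cycle` decides what to forward. A toy sketch of that producer/consumer split, with invented event shapes (`StopEvent` and the plain dicts are not the SDK's classes):

```python
import asyncio
from typing import Any, AsyncGenerator


class StopEvent(dict):
    """Illustrative stand-in for a terminal stop event; carries the final payload."""


async def stream_model_events() -> AsyncGenerator[Any, None]:
    # Inner generator: yields every event, including the final stop event.
    yield {"delta": "hello"}
    yield {"delta": " world"}
    yield StopEvent({"stop": ("end_turn", {"text": "hello world"})})


async def run_cycle() -> None:
    # Caller: forwards streaming events but keeps the stop event for itself.
    event = None
    async for event in stream_model_events():
        if not isinstance(event, StopEvent):
            print("forwarded:", event)
    stop_reason, message = event["stop"]  # the last yielded event carries the stop data
    print("stopped:", stop_reason, message)


asyncio.run(run_cycle())
```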
@@ -198,108 +323,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
             # Add the response message to the conversation
             agent.messages.append(message)
             agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message))
-            yield ModelMessageEvent(message=message)
 
             # Update metrics
             agent.event_loop_metrics.update_usage(usage)
             agent.event_loop_metrics.update_metrics(metrics)
 
-            if stop_reason == "max_tokens":
-                """
-                Handle max_tokens limit reached by the model.
-
-                When the model reaches its maximum token limit, this represents a potentially unrecoverable
-                state where the model's response was truncated. By default, Strands fails hard with an
-                MaxTokensReachedException to maintain consistency with other failure types.
-                """
-                raise MaxTokensReachedException(
-                    message=(
-                        "Agent has reached an unrecoverable state due to max_tokens limit. "
-                        "For more information see: "
-                        "https://strandsagents.com/latest/user-guide/concepts/agents/agent-loop/#maxtokensreachedexception"
-                    )
-                )
-
-            # If the model is requesting to use tools
-            if stop_reason == "tool_use":
-                # Handle tool execution
-                events = _handle_tool_execution(
-                    stop_reason,
-                    message,
-                    agent=agent,
-                    cycle_trace=cycle_trace,
-                    cycle_span=cycle_span,
-                    cycle_start_time=cycle_start_time,
-                    invocation_state=invocation_state,
-                )
-                async for typed_event in events:
-                    yield typed_event
-
-                return
-
-            # End the cycle and return results
-            agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes)
-            if cycle_span:
-                tracer.end_event_loop_cycle_span(
-                    span=cycle_span,
-                    message=message,
-                )
-        except EventLoopException as e:
-            if cycle_span:
-                tracer.end_span_with_error(cycle_span, str(e), e)
-
-            # Don't yield or log the exception - we already did it when we
-            # raised the exception and we don't need that duplication.
-            raise
-        except (ContextWindowOverflowException, MaxTokensReachedException) as e:
-            # Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
-            if cycle_span:
-                tracer.end_span_with_error(cycle_span, str(e), e)
-            raise e
         except Exception as e:
             if cycle_span:
                 tracer.end_span_with_error(cycle_span, str(e), e)
 
-            # Handle any other exceptions
             yield ForceStopEvent(reason=e)
             logger.exception("cycle failed")
             raise EventLoopException(e, invocation_state["request_state"]) from e
 
-    yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
-
-
-async def recurse_event_loop(agent: "Agent", invocation_state: dict[str, Any]) -> AsyncGenerator[TypedEvent, None]:
-    """Make a recursive call to event_loop_cycle with the current state.
-
-    This function is used when the event loop needs to continue processing after tool execution.
-
-    Args:
-        agent: Agent for which the recursive call is being made.
-        invocation_state: Arguments to pass through event_loop_cycle
-
-
-    Yields:
-        Results from event_loop_cycle where the last result contains:
-
-        - StopReason: Reason the model stopped generating
-        - Message: The generated message from the model
-        - EventLoopMetrics: Updated metrics for the event loop
-        - Any: Updated request state
-    """
-    cycle_trace = invocation_state["event_loop_cycle_trace"]
-
-    # Recursive call trace
-    recursive_trace = Trace("Recursive call", parent_id=cycle_trace.id)
-    cycle_trace.add_child(recursive_trace)
-
-    yield StartEvent()
-
-    events = event_loop_cycle(agent=agent, invocation_state=invocation_state)
-    async for event in events:
-        yield event
-
-    recursive_trace.end()
-
 
 async def _handle_tool_execution(
     stop_reason: StopReason,

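Note: `recurse_event_loop` continues the conversation after tool execution by delegating to a fresh `event_loop_cycle` and streaming its events upward. The toy example below shows only that async-generator recursion pattern; `run_cycle` and the event strings are invented for illustration.

```python
import asyncio
from typing import AsyncGenerator


async def run_cycle(turns_left: int) -> AsyncGenerator[str, None]:
    # One "model turn"; if tools were requested, run them and recurse into a new cycle.
    yield f"model turn ({turns_left} turns left)"
    if turns_left > 1:
        yield "tool execution"
        # Recurse: stream every event from the nested cycle to the caller.
        async for event in run_cycle(turns_left - 1):
            yield event
    else:
        yield "final response"


async def main() -> None:
    async for event in run_cycle(3):
        print(event)


asyncio.run(main())
```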