2727 ModelStopReason ,
2828 StartEvent ,
2929 StartEventLoopEvent ,
30+ ToolInterruptEvent ,
3031 ToolResultMessageEvent ,
3132 TypedEvent ,
3233)
@@ -106,13 +107,19 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
106107 )
107108 invocation_state ["event_loop_cycle_span" ] = cycle_span
108109
109- model_events = _handle_model_execution ( agent , cycle_span , cycle_trace , invocation_state , tracer )
110- async for model_event in model_events :
111- if not isinstance ( model_event , ModelStopReason ):
112- yield model_event
110+ # Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls.
111+ if agent . _interrupt_state . activated :
112+ stop_reason : StopReason = "tool_use"
113+ message = agent . _interrupt_state . context [ "tool_use_message" ]
113114
114- stop_reason , message , * _ = model_event ["stop" ]
115- yield ModelMessageEvent (message = message )
115+ else :
116+ model_events = _handle_model_execution (agent , cycle_span , cycle_trace , invocation_state , tracer )
117+ async for model_event in model_events :
118+ if not isinstance (model_event , ModelStopReason ):
119+ yield model_event
120+
121+ stop_reason , message , * _ = model_event ["stop" ]
122+ yield ModelMessageEvent (message = message )
116123
117124 try :
118125 if stop_reason == "max_tokens" :
@@ -142,6 +149,7 @@ async def event_loop_cycle(agent: "Agent", invocation_state: dict[str, Any]) ->
142149 cycle_span = cycle_span ,
143150 cycle_start_time = cycle_start_time ,
144151 invocation_state = invocation_state ,
152+ tracer = tracer ,
145153 )
146154 async for tool_event in tool_events :
147155 yield tool_event
@@ -345,6 +353,7 @@ async def _handle_tool_execution(
345353 cycle_span : Any ,
346354 cycle_start_time : float ,
347355 invocation_state : dict [str , Any ],
356+ tracer : Tracer ,
348357) -> AsyncGenerator [TypedEvent , None ]:
349358 """Handles the execution of tools requested by the model during an event loop cycle.
350359
@@ -356,6 +365,7 @@ async def _handle_tool_execution(
356365 cycle_span: Span object for tracing the cycle (type may vary).
357366 cycle_start_time: Start time of the current cycle.
358367 invocation_state: Additional keyword arguments, including request state.
368+ tracer: Tracer instance for span management.
359369
360370 Yields:
361371 Tool stream events along with events yielded from a recursive call to the event loop. The last event is a tuple
@@ -375,15 +385,45 @@ async def _handle_tool_execution(
375385 yield EventLoopStopEvent (stop_reason , message , agent .event_loop_metrics , invocation_state ["request_state" ])
376386 return
377387
388+ if agent ._interrupt_state .activated :
389+ tool_results .extend (agent ._interrupt_state .context ["tool_results" ])
390+
391+ # Filter to only the interrupted tools when resuming from interrupt (tool uses without results)
392+ tool_use_ids = {tool_result ["toolUseId" ] for tool_result in tool_results }
393+ tool_uses = [tool_use for tool_use in tool_uses if tool_use ["toolUseId" ] not in tool_use_ids ]
394+
395+ interrupts = []
378396 tool_events = agent .tool_executor ._execute (
379397 agent , tool_uses , tool_results , cycle_trace , cycle_span , invocation_state
380398 )
381399 async for tool_event in tool_events :
400+ if isinstance (tool_event , ToolInterruptEvent ):
401+ interrupts .extend (tool_event ["tool_interrupt_event" ]["interrupts" ])
402+
382403 yield tool_event
383404
384405 # Store parent cycle ID for the next cycle
385406 invocation_state ["event_loop_parent_cycle_id" ] = invocation_state ["event_loop_cycle_id" ]
386407
408+ if interrupts :
409+ # Session state stored on AfterInvocationEvent.
410+ agent ._interrupt_state .activate (context = {"tool_use_message" : message , "tool_results" : tool_results })
411+
412+ agent .event_loop_metrics .end_cycle (cycle_start_time , cycle_trace )
413+ yield EventLoopStopEvent (
414+ "interrupt" ,
415+ message ,
416+ agent .event_loop_metrics ,
417+ invocation_state ["request_state" ],
418+ interrupts ,
419+ )
420+ if cycle_span :
421+ tracer .end_event_loop_cycle_span (span = cycle_span , message = message )
422+
423+ return
424+
425+ agent ._interrupt_state .deactivate ()
426+
387427 tool_result_message : Message = {
388428 "role" : "user" ,
389429 "content" : [{"toolResult" : result } for result in tool_results ],
@@ -394,7 +434,6 @@ async def _handle_tool_execution(
394434 yield ToolResultMessageEvent (message = tool_result_message )
395435
396436 if cycle_span :
397- tracer = get_tracer ()
398437 tracer .end_event_loop_cycle_span (span = cycle_span , message = message , tool_result_message = tool_result_message )
399438
400439 if invocation_state ["request_state" ].get ("stop_event_loop" , False ):
0 commit comments