diff --git a/src/agents/run.py b/src/agents/run.py
index 5b25df4f2..6b4418516 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -1072,6 +1072,7 @@ async def _start_streaming(
                     tool_use_tracker,
                     all_tools,
                     server_conversation_tracker,
+                    session,
                 )
 
                 should_run_agent_start_hooks = False
@@ -1107,7 +1108,7 @@ async def _start_streaming(
                         AgentUpdatedStreamEvent(new_agent=current_agent)
                     )
 
-                    # Check for soft cancel after handoff
+                    # Check for soft cancel after handoff (before next turn)
                     if streamed_result._cancel_mode == "after_turn":  # type: ignore[comparison-overlap]
                         streamed_result.is_complete = True
                         streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1158,7 +1159,7 @@ async def _start_streaming(
                             session, [], turn_result.new_step_items
                         )
 
-                    # Check for soft cancel after turn completion
+                    # Check for soft cancel after tool execution completes (before next turn)
                     if streamed_result._cancel_mode == "after_turn":  # type: ignore[comparison-overlap]
                         streamed_result.is_complete = True
                         streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1217,6 +1218,7 @@ async def _run_single_turn_streamed(
         tool_use_tracker: AgentToolUseTracker,
         all_tools: list[Tool],
         server_conversation_tracker: _ServerConversationTracker | None = None,
+        session: Session | None = None,
     ) -> SingleStepResult:
         emitted_tool_call_ids: set[str] = set()
         emitted_reasoning_item_ids: set[str] = set()
@@ -1369,6 +1371,114 @@ async def _run_single_turn_streamed(
         if not final_response:
             raise ModelBehaviorError("Model did not produce a final response!")
 
+        # Check for soft cancel after LLM response streaming completes (before tool execution)
+        # Only cancel here if there are no tools/handoffs to execute - otherwise let tools execute
+        # and the cancel will be honored after tool execution completes
+        if streamed_result._cancel_mode == "after_turn":
+            # Process the model response to check if there are tools/handoffs to execute
+            processed_response = RunImpl.process_model_response(
+                agent=agent,
+                all_tools=all_tools,
+                response=final_response,
+                output_schema=output_schema,
+                handoffs=handoffs,
+            )
+
+            # If there are tools, handoffs, or approvals to execute, let normal flow continue
+            # The cancel will be honored after tool execution completes (before next step)
+            if processed_response.has_tools_or_approvals_to_run() or processed_response.handoffs:
+                # Continue with normal flow - tools will execute,
+                # then cancel after execution completes
+                pass
+            else:
+                # No tools/handoffs to execute - safe to cancel here and skip tool execution
+                # Note: We intentionally skip execute_tools_and_side_effects() since there are
+                # no tools to execute. This allows faster cancellation when the LLM response
+                # contains no actions.
+                tool_use_tracker.add_tool_use(agent, processed_response.tools_used)
+
+                # Filter out items that have already been sent to avoid duplicates
+                items_to_save = list(processed_response.new_items)
+
+                if emitted_tool_call_ids:
+                    # Filter out tool call items that were already emitted during streaming
+                    items_to_save = [
+                        item
+                        for item in items_to_save
+                        if not (
+                            isinstance(item, ToolCallItem)
+                            and (
+                                call_id := getattr(
+                                    item.raw_item, "call_id", getattr(item.raw_item, "id", None)
+                                )
+                            )
+                            and call_id in emitted_tool_call_ids
+                        )
+                    ]
+
+                if emitted_reasoning_item_ids:
+                    # Filter out reasoning items that were already emitted during streaming
+                    items_to_save = [
+                        item
+                        for item in items_to_save
+                        if not (
+                            isinstance(item, ReasoningItem)
+                            and (reasoning_id := getattr(item.raw_item, "id", None))
+                            and reasoning_id in emitted_reasoning_item_ids
+                        )
+                    ]
+
+                # Filter out HandoffCallItem to avoid duplicates (already sent earlier)
+                items_to_save = [
+                    item for item in items_to_save if not isinstance(item, HandoffCallItem)
+                ]
+
+                # Create SingleStepResult with NextStepRunAgain (we're stopping mid-turn)
+                single_step_result = SingleStepResult(
+                    original_input=streamed_result.input,
+                    model_response=final_response,
+                    pre_step_items=streamed_result.new_items,
+                    new_step_items=items_to_save,
+                    next_step=NextStepRunAgain(),
+                    tool_input_guardrail_results=[],
+                    tool_output_guardrail_results=[],
+                )
+
+                # Save session with the model response items
+                # Exclude ToolCallItem objects to avoid saving incomplete tool calls without outputs
+                if session is not None:
+                    should_skip_session_save = (
+                        await AgentRunner._input_guardrail_tripwire_triggered_for_stream(
+                            streamed_result
+                        )
+                    )
+                    if should_skip_session_save is False:
+                        # Filter out tool calls - they don't have outputs yet, so shouldn't be saved
+                        # This prevents saving incomplete tool calls that violate API requirements
+                        items_for_session = [
+                            item for item in items_to_save if not isinstance(item, ToolCallItem)
+                        ]
+                        # Type ignore: intentionally filtering out ToolCallItem to avoid saving
+                        # incomplete tool calls without corresponding outputs
+                        await AgentRunner._save_result_to_session(
+                            session,
+                            [],
+                            items_for_session,  # type: ignore[arg-type]
+                        )
+
+                # Stream the items to the event queue
+                RunImpl.stream_step_result_to_queue(
+                    single_step_result, streamed_result._event_queue
+                )
+
+                # Mark as complete and signal completion
+                streamed_result.is_complete = True
+                streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
+
+                return single_step_result
+
         # 3. Now, we can process the turn as we do in the non-streaming case
         single_step_result = await cls._get_single_step_result_from_response(
             agent=agent,
diff --git a/tests/test_soft_cancel.py b/tests/test_soft_cancel.py
index 395f2fb6f..1f2d1044d 100644
--- a/tests/test_soft_cancel.py
+++ b/tests/test_soft_cancel.py
@@ -4,7 +4,8 @@
 
 import pytest
 
-from agents import Agent, Runner, SQLiteSession
+from agents import Agent, OutputGuardrail, Runner, SQLiteSession
+from agents.guardrail import GuardrailFunctionOutput
 
 from .fake_model import FakeModel
 from .test_responses import get_function_tool, get_function_tool_call, get_text_message
@@ -87,13 +88,15 @@ async def test_soft_cancel_with_tool_calls():
         if event.type == "run_item_stream_event":
             if event.name == "tool_called":
                 tool_call_seen = True
-                # Cancel right after seeing tool call
+                # Cancel right after seeing tool call - tools will execute
+                # then cancel is honored after tool execution completes
                 result.cancel(mode="after_turn")
             elif event.name == "tool_output":
                 tool_output_seen = True
 
     assert tool_call_seen, "Tool call should be seen"
-    assert tool_output_seen, "Tool output should be seen (tool should execute before soft cancel)"
+    assert tool_output_seen, "Tool output SHOULD be seen (tools execute before cancel is honored)"
+    assert result.is_complete, "Result should be marked complete"
 
 
 @pytest.mark.asyncio
@@ -293,18 +296,25 @@ async def test_soft_cancel_with_multiple_tool_calls():
 
     result = Runner.run_streamed(agent, input="Execute tools")
 
+    tool_calls_seen = 0
     tool_outputs_seen = 0
     async for event in result.stream_events():
         if event.type == "run_item_stream_event":
             if event.name == "tool_called":
-                # Cancel after seeing first tool call
-                if tool_outputs_seen == 0:
+                tool_calls_seen += 1
+                # Cancel after seeing first tool call - tools will execute
+                # then cancel is honored after tool execution completes
+                if tool_calls_seen == 1:
                     result.cancel(mode="after_turn")
             elif event.name == "tool_output":
                 tool_outputs_seen += 1
 
-    # Both tools should execute
-    assert tool_outputs_seen == 2, "Both tools should execute before soft cancel"
+    # Tool calls should be seen, and tools SHOULD execute before cancel is honored
+    assert tool_calls_seen >= 1, "Tool calls should be seen"
+    assert tool_outputs_seen > 0, (
+        "Tool outputs SHOULD be seen (tools execute before cancel is honored)"
+    )
+    assert result.is_complete, "Result should be marked complete"
 
 
 @pytest.mark.asyncio
@@ -476,3 +486,46 @@ async def test_soft_cancel_with_session_and_multiple_turns():
 
     # Cleanup
     await session.clear_session()
+
+
+@pytest.mark.asyncio
+async def test_soft_cancel_runs_output_guardrails_before_canceling():
+    """Verify output guardrails run even when cancellation happens after final output."""
+    model = FakeModel()
+
+    # Track if guardrail was called
+    guardrail_called = False
+
+    def output_guardrail_fn(context, agent, output):
+        nonlocal guardrail_called
+        guardrail_called = True
+        return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
+
+    agent = Agent(
+        name="Assistant",
+        model=model,
+        output_guardrails=[OutputGuardrail(guardrail_function=output_guardrail_fn)],
+    )
+
+    # Setup: agent produces final output
+    model.add_multiple_turn_outputs([[get_text_message("Final answer")]])
+
+    result = Runner.run_streamed(agent, input="What is the answer?")
+
+    # Cancel after seeing the message output event (indicates turn completed)
+    # but before consuming all events
+    async for event in result.stream_events():
+        if event.type == "run_item_stream_event" and event.name == "message_output_created":
+            # Cancel after turn completes - guardrails should still run
+            result.cancel(mode="after_turn")
+            # Don't break - continue consuming to let guardrails complete
+
+    # Guardrail should have been called
+    assert guardrail_called, "Output guardrail should run even when canceling after final output"
+
+    # Final output should be set
+    assert result.final_output is not None, "final_output should be set even when canceling"
+    assert result.final_output == "Final answer"
+
+    # Output guardrail results should be recorded
+    assert len(result.output_guardrail_results) == 1, "Output guardrail results should be recorded"
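
For reviewers, a minimal usage sketch of the soft-cancel flow this patch covers, built only from the public API the tests above exercise (Runner.run_streamed, stream_events, cancel(mode="after_turn"), is_complete, final_output). The agent setup and the should_stop predicate are illustrative placeholders, not part of the change:

    import asyncio

    from agents import Agent, Runner


    def should_stop(event) -> bool:
        # Placeholder stop condition; a real application would inspect the event
        # (or react to an external signal such as a user pressing "stop").
        return event.type == "run_item_stream_event"


    async def main() -> None:
        agent = Agent(name="Assistant", instructions="Answer briefly.")
        result = Runner.run_streamed(agent, input="Summarize the report")

        async for event in result.stream_events():
            if should_stop(event):
                # Soft cancel: the current turn (including any in-flight tool calls)
                # is allowed to finish, output guardrails still run, and the run
                # stops before the next turn instead of aborting immediately.
                result.cancel(mode="after_turn")
            # Keep consuming events so completion and session saving can finish.

        print(result.is_complete, result.final_output)


    asyncio.run(main())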