diff --git a/examples/basic/max_turns_resume.py b/examples/basic/max_turns_resume.py
new file mode 100644
index 000000000..15be34e4f
--- /dev/null
+++ b/examples/basic/max_turns_resume.py
@@ -0,0 +1,41 @@
+from typing import Annotated
+
+from agents import Agent, MaxTurnsExceeded, Runner, function_tool
+
+
+@function_tool
+def gather_facts(topic: Annotated[str, "The topic to investigate"]) -> str:
+    """Return placeholder research that simulates a tool lookup."""
+    return (
+        f"Key facts about {topic}: it moves through evaporation, condensation, "
+        "precipitation, and collection."
+    )
+
+
+def main():
+    agent = Agent(
+        name="Researcher",
+        instructions=(
+            "You must call the gather_facts tool before answering. "
+            "Once you have the tool output, summarize it in your own words."
+        ),
+        tools=[gather_facts],
+    )
+
+    try:
+        Runner.run_sync(
+            agent,
+            input="Give me the main stages of the water cycle.",
+            max_turns=1,
+        )
+    except MaxTurnsExceeded as max_turns_exc:
+        print("Reached the max turn limit. Asking the agent to finalize without tools...\n")
+        result = max_turns_exc.resume_sync(
+            "Finish the answer using the gathered information without calling tools again."
+        )
+        print(result.final_output)
+        # The water cycle proceeds through evaporation, condensation, precipitation, and collection.
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/agents/exceptions.py b/src/agents/exceptions.py
index 39518c39d..ac9007550 100644
--- a/src/agents/exceptions.py
+++ b/src/agents/exceptions.py
@@ -1,12 +1,15 @@
 from __future__ import annotations
 
-from dataclasses import dataclass
+from dataclasses import dataclass, replace
+from textwrap import dedent
 from typing import TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
     from .agent import Agent
     from .guardrail import InputGuardrailResult, OutputGuardrailResult
     from .items import ModelResponse, RunItem, TResponseInputItem
+    from .result import RunResult
+    from .run import RunConfig
     from .run_context import RunContextWrapper
     from .tool_guardrails import (
         ToolGuardrailFunctionOutput,
@@ -28,6 +31,7 @@ class RunErrorDetails:
     context_wrapper: RunContextWrapper[Any]
     input_guardrail_results: list[InputGuardrailResult]
     output_guardrail_results: list[OutputGuardrailResult]
+    run_config: RunConfig
 
     def __str__(self) -> str:
         return pretty_print_run_error_details(self)
@@ -48,10 +52,111 @@ class MaxTurnsExceeded(AgentsException):
 
     message: str
 
+    _DEFAULT_RESUME_PROMPT = """
+    You reached the maximum number of turns.
+    Return a final answer to the query using ONLY the information already gathered \
+    in the conversation so far.
+    """
+
     def __init__(self, message: str):
         self.message = message
         super().__init__(message)
 
+    async def resume(self, prompt: str | None = _DEFAULT_RESUME_PROMPT) -> RunResult:
+        """Resume the failed run asynchronously with a final, tool-free turn.
+
+        Note:
+            This helper does not automatically reuse the original session object.
+            If you need the resumed turn to be persisted in the session,
+            run the follow-up turn manually and pass the session yourself.
+
+        Args:
+            prompt: Optional user instruction to append before rerunning the final turn.
+                Pass ``None`` to skip injecting an extra message; defaults to a reminder
+                to produce a final answer from existing context.
+ """ + run_data = self._require_run_data() + inputs, run_config = self._prepare_resume_arguments(run_data, prompt) + + from .run import Runner + + return await Runner.run( + starting_agent=run_data.last_agent, + input=inputs, + context=run_data.context_wrapper.context, + max_turns=1, + run_config=run_config, + ) + + def resume_sync(self, prompt: str | None = _DEFAULT_RESUME_PROMPT) -> RunResult: + """Resume the failed run synchronously with a final, tool-free turn. + + Note: + This helper does not automatically reuse the original session object. + If you need the resumed turn to be persisted in the session, + run the follow-up turn manually with that information. + + Args: + prompt: Optional user instruction to append before rerunning the final turn. + Pass ``None`` to skip injecting an extra message; defaults to a reminder + to produce a final answer from existing context. + """ + run_data = self._require_run_data() + inputs, run_config = self._prepare_resume_arguments(run_data, prompt) + + from .run import Runner + + return Runner.run_sync( + starting_agent=run_data.last_agent, + input=inputs, + context=run_data.context_wrapper.context, + max_turns=1, + run_config=run_config, + ) + + def _prepare_resume_arguments( + self, + run_data: RunErrorDetails, + prompt: str | None = None, + ) -> tuple[list[TResponseInputItem], RunConfig]: + from .items import ItemHelpers + from .model_settings import ModelSettings + + history: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(run_data.input) + for item in run_data.new_items: + history.append(item.to_input_item()) + + normalized_prompt = self._normalize_resume_prompt(prompt) + if normalized_prompt is not None: + history.append({"content": normalized_prompt, "role": "user"}) + + run_config = replace(run_data.run_config) + if run_config.model_settings is None: + run_config.model_settings = ModelSettings(tool_choice="none") + else: + run_config.model_settings = run_config.model_settings.resolve( + ModelSettings(tool_choice="none") + ) + + return ( + history, + run_config, + ) + + def _normalize_resume_prompt(self, prompt: str | None) -> str | None: + if prompt is None: + return None + normalized = dedent(prompt).strip() + return normalized or None + + def _require_run_data(self) -> RunErrorDetails: + if self.run_data is None: + raise RuntimeError( + "Run data is not available; resume() can only be called on\ + exceptions raised by Runner." + ) + return self.run_data + class ModelBehaviorError(AgentsException): """Exception raised when the model does something unexpected, e.g. 
calling a tool that doesn't diff --git a/src/agents/result.py b/src/agents/result.py index 3fe20cfa5..0aadfae87 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -3,7 +3,7 @@ import abc import asyncio from collections.abc import AsyncIterator -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Any, Literal, cast from typing_extensions import TypeVar @@ -31,6 +31,7 @@ if TYPE_CHECKING: from ._run_impl import QueueCompleteSentinel from .agent import Agent + from .run import RunConfig from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult T = TypeVar("T") @@ -69,6 +70,9 @@ class RunResultBase(abc.ABC): context_wrapper: RunContextWrapper[Any] """The context wrapper for the agent run.""" + run_config: RunConfig + """The run configuration that was used for the agent run.""" + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -279,6 +283,7 @@ def _create_error_details(self) -> RunErrorDetails: context_wrapper=self.context_wrapper, input_guardrail_results=self.input_guardrail_results, output_guardrail_results=self.output_guardrail_results, + run_config=replace(self.run_config), ) def _check_errors(self): diff --git a/src/agents/run.py b/src/agents/run.py index 5b25df4f2..6a97cd352 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -5,7 +5,7 @@ import inspect import os import warnings -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from typing import Any, Callable, Generic, cast, get_args from openai.types.responses import ( @@ -665,6 +665,7 @@ async def run( output_guardrail_results=output_guardrail_results, tool_input_guardrail_results=tool_input_guardrail_results, tool_output_guardrail_results=tool_output_guardrail_results, + run_config=replace(run_config), context_wrapper=context_wrapper, ) if not any( @@ -702,6 +703,7 @@ async def run( context_wrapper=context_wrapper, input_guardrail_results=input_guardrail_results, output_guardrail_results=[], + run_config=replace(run_config), ) raise finally: @@ -837,6 +839,7 @@ def run_streamed( output_guardrail_results=[], tool_input_guardrail_results=[], tool_output_guardrail_results=[], + run_config=replace(run_config), _current_agent_output_schema=output_schema, trace=new_trace, context_wrapper=context_wrapper, @@ -1174,6 +1177,7 @@ async def _start_streaming( context_wrapper=context_wrapper, input_guardrail_results=streamed_result.input_guardrail_results, output_guardrail_results=streamed_result.output_guardrail_results, + run_config=replace(run_config), ) raise except Exception as e: diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index 40edb99fe..b1a52177e 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -7,7 +7,7 @@ pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails -from agents import Agent, Runner, TResponseInputItem, function_tool +from agents import Agent, RunConfig, Runner, TResponseInputItem, function_tool from agents.extensions.memory import AdvancedSQLiteSession from agents.result import RunResult from agents.run_context import RunContextWrapper @@ -74,6 +74,7 @@ def create_mock_run_result( tool_output_guardrail_results=[], context_wrapper=context_wrapper, _last_agent=agent, + 
run_config=RunConfig(), ) diff --git a/tests/test_max_turns.py b/tests/test_max_turns.py index f01bb18ff..1481ebe66 100644 --- a/tests/test_max_turns.py +++ b/tests/test_max_turns.py @@ -5,7 +5,7 @@ import pytest from typing_extensions import TypedDict -from agents import Agent, MaxTurnsExceeded, Runner +from agents import Agent, MaxTurnsExceeded, ModelSettings, RunConfig, Runner from .fake_model import FakeModel from .test_responses import get_function_tool, get_function_tool_call, get_text_message @@ -75,6 +75,109 @@ async def test_streamed_max_turns(): pass +@pytest.mark.asyncio +async def test_max_turns_resume_runs_final_turn(): + model = FakeModel() + agent = Agent( + name="test_1", + model=model, + tools=[get_function_tool("some_function", "result")], + ) + + func_output = json.dumps({"a": "b"}) + final_answer = "final answer" + + model.add_multiple_turn_outputs( + [ + [get_text_message("1"), get_function_tool_call("some_function", func_output)], + [get_text_message("2"), get_function_tool_call("some_function", func_output)], + [get_text_message(final_answer)], + ] + ) + + with pytest.raises(MaxTurnsExceeded) as exc_info: + await Runner.run(agent, input="user_message", max_turns=2) + + result = await exc_info.value.resume("Finish without tools.") + + assert result.final_output == final_answer + resume_input = model.last_turn_args["input"] + assert resume_input[0]["content"] == "user_message" + assert resume_input[-1] == {"content": "Finish without tools.", "role": "user"} + assert any(item.get("type") == "function_call_output" for item in resume_input) + assert model.last_turn_args["model_settings"].tool_choice == "none" + + +def test_max_turns_resume_sync_uses_default_prompt(): + model = FakeModel() + agent = Agent( + name="test_1", + model=model, + tools=[get_function_tool("some_function", "result")], + ) + + func_output = json.dumps({"a": "b"}) + final_answer = "final answer" + + model.add_multiple_turn_outputs( + [ + [get_text_message("1"), get_function_tool_call("some_function", func_output)], + [get_text_message("2"), get_function_tool_call("some_function", func_output)], + [get_text_message(final_answer)], + ] + ) + + with pytest.raises(MaxTurnsExceeded) as exc_info: + Runner.run_sync(agent, input="user_message", max_turns=2) + + resume_prompt = "Return a final answer to the query using ONLY the information already gathered" + result = exc_info.value.resume_sync(resume_prompt) + + assert result.final_output == final_answer + resume_input = model.last_turn_args["input"] + assert resume_input[-1] == {"content": resume_prompt, "role": "user"} + assert model.last_turn_args["model_settings"].tool_choice == "none" + + +@pytest.mark.asyncio +async def test_resume_preserves_run_config_settings(): + model = FakeModel() + agent = Agent( + name="test_1", + model=model, + tools=[get_function_tool("some_function", "result")], + ) + + func_output = json.dumps({"a": "b"}) + final_answer = "final answer" + + model.add_multiple_turn_outputs( + [ + [get_text_message("1"), get_function_tool_call("some_function", func_output)], + [get_text_message("2"), get_function_tool_call("some_function", func_output)], + [get_text_message(final_answer)], + ] + ) + + run_config = RunConfig(model_settings=ModelSettings(temperature=0.25, tool_choice="auto")) + + with pytest.raises(MaxTurnsExceeded) as exc_info: + await Runner.run(agent, input="user_message", max_turns=2, run_config=run_config) + + await exc_info.value.resume("Finish without tools.") + + final_settings = model.last_turn_args["model_settings"] + 
assert final_settings.temperature == 0.25 + assert final_settings.tool_choice == "none" + + run_data = exc_info.value.run_data + assert run_data is not None + stored_settings = run_data.run_config.model_settings + assert stored_settings is not None + assert stored_settings.temperature == 0.25 + assert stored_settings.tool_choice == "auto" + + class Foo(TypedDict): a: str diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index 4ef1a293d..baa886913 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -3,7 +3,7 @@ import pytest from pydantic import BaseModel -from agents import Agent, RunContextWrapper, RunResult +from agents import Agent, RunConfig, RunContextWrapper, RunResult def create_run_result(final_output: Any) -> RunResult: @@ -18,6 +18,7 @@ def create_run_result(final_output: Any) -> RunResult: tool_output_guardrail_results=[], _last_agent=Agent(name="test"), context_wrapper=RunContextWrapper(context=None), + run_config=RunConfig(), ) diff --git a/tests/test_run_error_details.py b/tests/test_run_error_details.py index 104b248fc..b92dee3d2 100644 --- a/tests/test_run_error_details.py +++ b/tests/test_run_error_details.py @@ -2,7 +2,7 @@ import pytest -from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner +from agents import Agent, MaxTurnsExceeded, RunConfig, RunErrorDetails, Runner from .fake_model import FakeModel from .test_responses import get_function_tool, get_function_tool_call, get_text_message @@ -25,6 +25,7 @@ async def test_run_error_includes_data(): assert data.last_agent == agent assert len(data.raw_responses) == 1 assert len(data.new_items) > 0 + assert isinstance(data.run_config, RunConfig) @pytest.mark.asyncio @@ -46,3 +47,4 @@ async def test_streamed_run_error_includes_data(): assert data.last_agent == agent assert len(data.raw_responses) == 1 assert len(data.new_items) > 0 + assert isinstance(data.run_config, RunConfig)