Skip to content

Commit de986ef

Browse files
mjschockclaude
andcommitted
feat: add RunState parameter support to Runner.run() methods
This commit integrates RunState into the Runner API, allowing runs to be resumed from a saved state. This is the final piece needed to make human-in-the-loop (HITL) tool approval fully functional. **Changes:** 1. **Import NextStepInterruption** (run.py:21-32) - Added NextStepInterruption to imports from _run_impl - Added RunState import 2. **Updated Method Signatures** (run.py:285-444) - Runner.run(): Added `RunState[TContext]` to input union type - Runner.run_sync(): Added `RunState[TContext]` to input union type - Runner.run_streamed(): Added `RunState[TContext]` to input union type - AgentRunner.run(): Added `RunState[TContext]` to input union type - AgentRunner.run_sync(): Added `RunState[TContext]` to input union type - AgentRunner.run_streamed(): Added `RunState[TContext]` to input union type 3. **RunState Resumption Logic** (run.py:524-584) - Check if input is RunState instance - Extract state fields when resuming: current_turn, original_input, generated_items, model_responses, context_wrapper - Prime server conversation tracker from model_responses if resuming - Cast context_wrapper to correct type after extraction 4. **Interruption Handling** (run.py:689-726) - Added `interruptions=[]` to successful RunResult creation - Added elif branch for NextStepInterruption - Return RunResult with interruptions when tool approval needed - Set final_output to None for interrupted runs 5. **RunResultStreaming Support** (run.py:879-918) - Handle RunState input for streaming runs - Added `interruptions=[]` field to RunResultStreaming creation - Extract original_input from RunState for result **How It Works:** When resuming from RunState: ```python # User approves/rejects tool calls on the state run_state.approve(approval_item) # Resume the run from where it left off result = await Runner.run(agent, run_state) ``` When a tool needs approval: 1. Run pauses at tool execution 2. Returns RunResult with interruptions=[ToolApprovalItem(...)] 3. User can inspect interruptions and approve/reject 4. User resumes by passing RunResult back to Runner.run() **Remaining Work:** - Add `state` property to RunResult for creating RunState from results - Add comprehensive tests - Add documentation/examples 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 44cdeb4 commit de986ef

File tree

1 file changed

+78
-21
lines changed

1 file changed

+78
-21
lines changed

src/agents/run.py

Lines changed: 78 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
AgentToolUseTracker,
2323
NextStepFinalOutput,
2424
NextStepHandoff,
25+
NextStepInterruption,
2526
NextStepRunAgain,
2627
QueueCompleteSentinel,
2728
RunImpl,
@@ -65,6 +66,7 @@
6566
from .models.multi_provider import MultiProvider
6667
from .result import RunResult, RunResultStreaming
6768
from .run_context import RunContextWrapper, TContext
69+
from .run_state import RunState
6870
from .stream_events import (
6971
AgentUpdatedStreamEvent,
7072
RawResponsesStreamEvent,
@@ -283,7 +285,7 @@ class Runner:
283285
async def run(
284286
cls,
285287
starting_agent: Agent[TContext],
286-
input: str | list[TResponseInputItem],
288+
input: str | list[TResponseInputItem] | RunState[TContext],
287289
*,
288290
context: TContext | None = None,
289291
max_turns: int = DEFAULT_MAX_TURNS,
@@ -358,7 +360,7 @@ async def run(
358360
def run_sync(
359361
cls,
360362
starting_agent: Agent[TContext],
361-
input: str | list[TResponseInputItem],
363+
input: str | list[TResponseInputItem] | RunState[TContext],
362364
*,
363365
context: TContext | None = None,
364366
max_turns: int = DEFAULT_MAX_TURNS,
@@ -431,7 +433,7 @@ def run_sync(
431433
def run_streamed(
432434
cls,
433435
starting_agent: Agent[TContext],
434-
input: str | list[TResponseInputItem],
436+
input: str | list[TResponseInputItem] | RunState[TContext],
435437
context: TContext | None = None,
436438
max_turns: int = DEFAULT_MAX_TURNS,
437439
hooks: RunHooks[TContext] | None = None,
@@ -506,7 +508,7 @@ class AgentRunner:
506508
async def run(
507509
self,
508510
starting_agent: Agent[TContext],
509-
input: str | list[TResponseInputItem],
511+
input: str | list[TResponseInputItem] | RunState[TContext],
510512
**kwargs: Unpack[RunOptions[TContext]],
511513
) -> RunResult:
512514
context = kwargs.get("context")
@@ -519,19 +521,41 @@ async def run(
519521
if run_config is None:
520522
run_config = RunConfig()
521523

524+
# Check if we're resuming from a RunState
525+
is_resumed_state = isinstance(input, RunState)
526+
run_state: RunState[TContext] | None = None
527+
528+
if is_resumed_state:
529+
# Resuming from a saved state
530+
run_state = cast(RunState[TContext], input)
531+
original_user_input = run_state._original_input
532+
prepared_input = run_state._original_input
533+
534+
# Override context with the state's context if not provided
535+
if context is None and run_state._context is not None:
536+
context = run_state._context.context
537+
else:
538+
# Keep original user input separate from session-prepared input
539+
raw_input = cast(str | list[TResponseInputItem], input)
540+
original_user_input = raw_input
541+
prepared_input = await self._prepare_input_with_session(
542+
raw_input, session, run_config.session_input_callback
543+
)
544+
522545
if conversation_id is not None or previous_response_id is not None:
523546
server_conversation_tracker = _ServerConversationTracker(
524547
conversation_id=conversation_id, previous_response_id=previous_response_id
525548
)
526549
else:
527550
server_conversation_tracker = None
528551

529-
# Keep original user input separate from session-prepared input
530-
original_user_input = input
531-
prepared_input = await self._prepare_input_with_session(
532-
input, session, run_config.session_input_callback
533-
)
552+
# Prime the server conversation tracker from state if resuming
553+
if server_conversation_tracker is not None and is_resumed_state and run_state is not None:
554+
for response in run_state._model_responses:
555+
server_conversation_tracker.track_server_items(response)
534556

557+
# Always create a fresh tool_use_tracker
558+
# (it's rebuilt from the run state if needed during execution)
535559
tool_use_tracker = AgentToolUseTracker()
536560

537561
with TraceCtxManager(
@@ -541,14 +565,23 @@ async def run(
541565
metadata=run_config.trace_metadata,
542566
disabled=run_config.tracing_disabled,
543567
):
544-
current_turn = 0
545-
original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input)
546-
generated_items: list[RunItem] = []
547-
model_responses: list[ModelResponse] = []
548-
549-
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
550-
context=context, # type: ignore
551-
)
568+
if is_resumed_state and run_state is not None:
569+
# Restore state from RunState
570+
current_turn = run_state._current_turn
571+
original_input = run_state._original_input
572+
generated_items = run_state._generated_items
573+
model_responses = run_state._model_responses
574+
# Cast to the correct type since we know this is TContext
575+
context_wrapper = cast(RunContextWrapper[TContext], run_state._context)
576+
else:
577+
# Fresh run
578+
current_turn = 0
579+
original_input = _copy_str_or_list(prepared_input)
580+
generated_items = []
581+
model_responses = []
582+
context_wrapper = RunContextWrapper(
583+
context=context, # type: ignore
584+
)
552585

553586
input_guardrail_results: list[InputGuardrailResult] = []
554587
tool_input_guardrail_results: list[ToolInputGuardrailResult] = []
@@ -666,6 +699,7 @@ async def run(
666699
tool_input_guardrail_results=tool_input_guardrail_results,
667700
tool_output_guardrail_results=tool_output_guardrail_results,
668701
context_wrapper=context_wrapper,
702+
interruptions=[],
669703
)
670704
if not any(
671705
guardrail_result.output.tripwire_triggered
@@ -675,6 +709,22 @@ async def run(
675709
session, [], turn_result.new_step_items
676710
)
677711

712+
return result
713+
elif isinstance(turn_result.next_step, NextStepInterruption):
714+
# Tool approval is needed - return a result with interruptions
715+
result = RunResult(
716+
input=original_input,
717+
new_items=generated_items,
718+
raw_responses=model_responses,
719+
final_output=None,
720+
_last_agent=current_agent,
721+
input_guardrail_results=input_guardrail_results,
722+
output_guardrail_results=[],
723+
tool_input_guardrail_results=tool_input_guardrail_results,
724+
tool_output_guardrail_results=tool_output_guardrail_results,
725+
context_wrapper=context_wrapper,
726+
interruptions=turn_result.next_step.interruptions,
727+
)
678728
return result
679729
elif isinstance(turn_result.next_step, NextStepHandoff):
680730
current_agent = cast(Agent[TContext], turn_result.next_step.new_agent)
@@ -711,7 +761,7 @@ async def run(
711761
def run_sync(
712762
self,
713763
starting_agent: Agent[TContext],
714-
input: str | list[TResponseInputItem],
764+
input: str | list[TResponseInputItem] | RunState[TContext],
715765
**kwargs: Unpack[RunOptions[TContext]],
716766
) -> RunResult:
717767
context = kwargs.get("context")
@@ -790,7 +840,7 @@ def run_sync(
790840
def run_streamed(
791841
self,
792842
starting_agent: Agent[TContext],
793-
input: str | list[TResponseInputItem],
843+
input: str | list[TResponseInputItem] | RunState[TContext],
794844
**kwargs: Unpack[RunOptions[TContext]],
795845
) -> RunResultStreaming:
796846
context = kwargs.get("context")
@@ -824,8 +874,14 @@ def run_streamed(
824874
context=context # type: ignore
825875
)
826876

877+
# Handle RunState input
878+
if isinstance(input, RunState):
879+
input_for_result = input._original_input
880+
else:
881+
input_for_result = input
882+
827883
streamed_result = RunResultStreaming(
828-
input=_copy_str_or_list(input),
884+
input=_copy_str_or_list(input_for_result),
829885
new_items=[],
830886
current_agent=starting_agent,
831887
raw_responses=[],
@@ -840,12 +896,13 @@ def run_streamed(
840896
_current_agent_output_schema=output_schema,
841897
trace=new_trace,
842898
context_wrapper=context_wrapper,
899+
interruptions=[],
843900
)
844901

845902
# Kick off the actual agent loop in the background and return the streamed result object.
846903
streamed_result._run_impl_task = asyncio.create_task(
847904
self._start_streaming(
848-
starting_input=input,
905+
starting_input=input_for_result,
849906
streamed_result=streamed_result,
850907
starting_agent=starting_agent,
851908
max_turns=max_turns,

0 commit comments

Comments
 (0)