Skip to content

Commit fd51efa

Browse files
authored
Fix model replacement on looping handoffs (#1122)
* Keep to-be-replaced agents in memory so they can be replaced in looping handoffs
* Unrelated: add clippy to CI
* The new question results in continue-as-new, which is not supported correctly in the workflow
* Change error type
1 parent 2f04a16 commit fd51efa

File tree

3 files changed

+99
-57
lines changed

3 files changed

+99
-57
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ jobs:
5555
with:
5656
submodules: recursive
5757
- uses: dtolnay/rust-toolchain@stable
58+
with:
59+
components: "clippy"
5860
- uses: Swatinem/rust-cache@v2
5961
with:
6062
workspaces: temporalio/bridge -> target

temporalio/contrib/openai_agents/_openai_runner.py

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import dataclasses
2-
import json
32
import typing
43
from typing import Any, Optional, Union
54

@@ -17,14 +16,62 @@
1716
TResponseInputItem,
1817
)
1918
from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
20-
from pydantic_core import to_json
2119

2220
from temporalio import workflow
2321
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
2422
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
2523
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
2624

2725

26+
# Recursively replace models in all agents
27+
def _convert_agent(
    model_params: ModelActivityParameters,
    agent: Agent[Any],
    seen: Optional[dict[int, Agent]] = None,
) -> Agent[Any]:
    """Return a copy of ``agent`` whose model is a ``_TemporalModelStub``.

    The conversion is applied recursively to every entry in ``agent.handoffs``:
    ``Agent`` entries are converted immediately, while ``Handoff`` entries get
    their ``on_invoke_handoff`` wrapped so that the agent they produce is
    converted when the handoff is invoked.

    Args:
        model_params: Parameters passed to every ``_TemporalModelStub`` created.
        agent: The agent to convert. It is not mutated; a copy is returned.
        seen: Maps ``id(original_agent)`` to its converted copy. Pass ``None``
            on the initial call; the same dict is shared by all recursive and
            deferred (handoff-time) conversions so circular handoffs resolve to
            the already-converted copy.

    Returns:
        The converted copy, the previously converted copy if ``agent`` was
        already seen, or ``agent`` itself if its model is already a
        ``_TemporalModelStub``.

    Raises:
        TypeError: If an entry in ``agent.handoffs`` is neither an ``Agent``
            nor a ``Handoff``.
    """
    if seen is None:
        seen = {}

    # Short circuit if this agent was already seen, to prevent infinite
    # recursion on circular handoffs. Return the converted copy (not the
    # original) so the loop points at an agent with a replaced model.
    if id(agent) in seen:
        return seen[id(agent)]

    # This agent has already been processed in some other run
    if isinstance(agent.model, _TemporalModelStub):
        return agent

    # Register the copy *before* recursing so that handoffs which loop back to
    # this agent find it in ``seen``. Model and handoffs are filled in below.
    new_agent = dataclasses.replace(agent)
    seen[id(agent)] = new_agent

    name = _model_name(agent)

    new_handoffs: list[Union[Agent, Handoff]] = []
    for handoff in agent.handoffs:
        if isinstance(handoff, Agent):
            new_handoffs.append(_convert_agent(model_params, handoff, seen))
        elif isinstance(handoff, Handoff):
            # Build the wrapper through a factory so each wrapper is bound to
            # *this* handoff's callback. Defining the closure directly in the
            # loop would late-bind the callback variable and make every
            # wrapper invoke the last handoff's callback.
            new_handoffs.append(
                dataclasses.replace(
                    handoff,
                    on_invoke_handoff=_wrap_on_invoke_handoff(
                        model_params, handoff.on_invoke_handoff, seen
                    ),
                )
            )
        else:
            raise TypeError(f"Unknown handoff type: {type(handoff)}")

    new_agent.model = _TemporalModelStub(
        model_name=name,
        model_params=model_params,
        agent=agent,
    )
    new_agent.handoffs = new_handoffs
    return new_agent


def _wrap_on_invoke_handoff(
    model_params: ModelActivityParameters,
    original_invoke: typing.Callable[..., typing.Awaitable[Any]],
    seen: dict[int, Agent],
) -> typing.Callable[..., typing.Awaitable[Any]]:
    """Wrap a handoff callback so the agent it returns is converted on invoke."""

    async def on_invoke(context: RunContextWrapper[Any], args: str) -> Agent:
        handoff_agent = await original_invoke(context, args)
        return _convert_agent(model_params, handoff_agent, seen)

    return on_invoke
73+
74+
2875
class TemporalOpenAIRunner(AgentRunner):
2976
"""Temporal Runner for OpenAI agents.
3077
@@ -101,54 +148,9 @@ async def run(
101148
),
102149
)
103150

104-
# Recursively replace models in all agents
105-
def convert_agent(agent: Agent[Any], seen: Optional[set[int]]) -> Agent[Any]:
106-
if seen is None:
107-
seen = set()
108-
109-
# Short circuit if this model was already seen to prevent looping from circular handoffs
110-
if id(agent) in seen:
111-
return agent
112-
seen.add(id(agent))
113-
114-
# This agent has already been processed in some other run
115-
if isinstance(agent.model, _TemporalModelStub):
116-
return agent
117-
118-
name = _model_name(agent)
119-
120-
new_handoffs: list[Union[Agent, Handoff]] = []
121-
for handoff in agent.handoffs:
122-
if isinstance(handoff, Agent):
123-
new_handoffs.append(convert_agent(handoff, seen))
124-
elif isinstance(handoff, Handoff):
125-
original_invoke = handoff.on_invoke_handoff
126-
127-
async def on_invoke(
128-
context: RunContextWrapper[Any], args: str
129-
) -> Agent:
130-
handoff_agent = await original_invoke(context, args)
131-
return convert_agent(handoff_agent, seen)
132-
133-
new_handoffs.append(
134-
dataclasses.replace(handoff, on_invoke_handoff=on_invoke)
135-
)
136-
else:
137-
raise ValueError(f"Unknown handoff type: {type(handoff)}")
138-
139-
return dataclasses.replace(
140-
agent,
141-
model=_TemporalModelStub(
142-
model_name=name,
143-
model_params=self.model_params,
144-
agent=agent,
145-
),
146-
handoffs=new_handoffs,
147-
)
148-
149151
try:
150152
return await self._runner.run(
151-
starting_agent=convert_agent(starting_agent, None),
153+
starting_agent=_convert_agent(self.model_params, starting_agent, None),
152154
input=input,
153155
context=context,
154156
max_turns=max_turns,

tests/contrib/openai_agents/test_openai.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,13 @@
9696
TestModelProvider,
9797
)
9898
from temporalio.contrib.openai_agents._model_parameters import ModelSummaryProvider
99-
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
99+
from temporalio.contrib.openai_agents._openai_runner import _convert_agent
100+
from temporalio.contrib.openai_agents._temporal_model_stub import (
101+
_extract_summary,
102+
_TemporalModelStub,
103+
)
100104
from temporalio.contrib.pydantic import pydantic_data_converter
101-
from temporalio.exceptions import ApplicationError, CancelledError
105+
from temporalio.exceptions import ApplicationError, CancelledError, TemporalError
102106
from temporalio.testing import WorkflowEnvironment
103107
from temporalio.workflow import ActivityConfig
104108
from tests.contrib.openai_agents.research_agents.research_manager import (
@@ -897,7 +901,10 @@ async def update_seat(
897901
async def on_seat_booking_handoff(
898902
context: RunContextWrapper[AirlineAgentContext],
899903
) -> None:
900-
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
904+
try:
905+
flight_number = f"FLT-{workflow.random().randint(100, 999)}"
906+
except TemporalError:
907+
flight_number = "FLT-100"
901908
context.context.flight_number = flight_number
902909

903910

@@ -975,6 +982,8 @@ class CustomerServiceModel(StaticTestModel):
975982
ResponseBuilders.output_message(
976983
"Your seat has been updated to a window seat. If there's anything else you need, feel free to let me know!"
977984
),
985+
ResponseBuilders.tool_call("{}", "transfer_to_triage_agent"),
986+
ResponseBuilders.output_message("You're welcome!"),
978987
]
979988

980989

@@ -988,10 +997,7 @@ def __init__(self, input_items: list[TResponseInputItem] = []):
988997

989998
@workflow.run
990999
async def run(self, input_items: list[TResponseInputItem] = []):
991-
await workflow.wait_condition(
992-
lambda: workflow.info().is_continue_as_new_suggested()
993-
and workflow.all_handlers_finished()
994-
)
1000+
await workflow.wait_condition(lambda: False)
9951001
workflow.continue_as_new(self.input_items)
9961002

9971003
@workflow.query
@@ -1062,7 +1068,13 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
10621068
]
10631069
client = Client(**new_config)
10641070

1065-
questions = ["Hello", "Book me a flight to PDX", "11111", "Any window seat"]
1071+
questions = [
1072+
"Hello",
1073+
"Book me a flight to PDX",
1074+
"11111",
1075+
"Any window seat",
1076+
"Take me back to the triage agent to say goodbye",
1077+
]
10661078

10671079
async with new_worker(
10681080
client,
@@ -1101,7 +1113,7 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
11011113
if e.HasField("activity_task_completed_event_attributes"):
11021114
events.append(e)
11031115

1104-
assert len(events) == 6
1116+
assert len(events) == 8
11051117
assert (
11061118
"Hi there! How can I assist you today?"
11071119
in events[0]
@@ -1138,6 +1150,18 @@ async def test_customer_service_workflow(client: Client, use_local_model: bool):
11381150
.activity_task_completed_event_attributes.result.payloads[0]
11391151
.data.decode()
11401152
)
1153+
assert (
1154+
"transfer_to_triage_agent"
1155+
in events[6]
1156+
.activity_task_completed_event_attributes.result.payloads[0]
1157+
.data.decode()
1158+
)
1159+
assert (
1160+
"You're welcome!"
1161+
in events[7]
1162+
.activity_task_completed_event_attributes.result.payloads[0]
1163+
.data.decode()
1164+
)
11411165

11421166

11431167
class InputGuardrailModel(OpenAIResponsesModel):
@@ -2571,3 +2595,17 @@ def override_get_activities() -> Sequence[Callable]:
25712595
err.value.cause.message
25722596
== "MCP Stateful Server Worker failed to schedule activity."
25732597
)
2598+
2599+
2600+
async def test_model_conversion_loops():
    """Check that an agent reached through a handoff loop has its model replaced.

    Converts the agents returned by ``init_agents()``, invokes the handoff at
    index 1 of the converted agent, and verifies that the first handoff of the
    resulting agent is an ``Agent`` whose model is a ``_TemporalModelStub``.
    """
    root_agent = init_agents()
    converted_root = _convert_agent(ModelActivityParameters(), root_agent, None)

    # The second handoff entry is expected to be a Handoff object (not a bare
    # Agent), so the target agent is only produced when it is invoked.
    booking_handoff = converted_root.handoffs[1]
    assert isinstance(booking_handoff, Handoff)

    run_context: RunContextWrapper[AirlineAgentContext] = RunContextWrapper(
        context=AirlineAgentContext()  # type: ignore
    )
    booking_agent = await booking_handoff.on_invoke_handoff(run_context, "")

    # Follow the handoff that leads back from the invoked agent.
    looped_agent = booking_agent.handoffs[0]
    assert isinstance(looped_agent, Agent)
    assert isinstance(looped_agent.model, _TemporalModelStub)

0 commit comments

Comments
 (0)