
Commit 04dbc42

google-genai-bot authored and copybara-github committed
feat: Improve Tau-bench ADK colab stability
PiperOrigin-RevId: 825675599
1 parent 592c5d8 commit 04dbc42

File tree

4 files changed: +144, -44 lines

contributing/samples/gepa/adk_agent.py

Lines changed: 18 additions & 1 deletion

@@ -30,6 +30,7 @@
 from google.adk.agents import llm_agent
 from google.adk.agents import loop_agent
 from google.adk.events import event as event_lib
+from google.adk.models import google_llm
 from google.adk.tools import base_tool
 from google.genai import types

@@ -98,6 +99,15 @@ async def run_async(self, *, args: Dict[str, Any], tool_context: Any) -> str:
     return env_response.observation


+def _default_retry_options() -> types.HttpRetryOptions:
+  return types.HttpRetryOptions(
+      initial_delay=2,
+      attempts=4,
+      max_delay=None,
+      exp_base=2.0,
+  )
+
+
 def _adk_agent(
     instruction: str,
     tools: list[base_tool.BaseTool],

@@ -120,7 +130,10 @@ def _adk_agent(
   # TODO - Allow more flexibility in configuring the agent used in the loop.
   return llm_agent.LlmAgent(
       name=name or 'agent',
-      model=model or 'gemini-2.5-flash',
+      model=google_llm.Gemini(
+          model=model or 'gemini-2.5-flash',
+          retry_options=_default_retry_options(),
+      ),
       instruction=instruction,
       tools=tools,
       generate_content_config=types.GenerateContentConfig(

@@ -130,6 +143,10 @@ def _adk_agent(
               mode=types.FunctionCallingConfigMode.VALIDATED
           )
       ),
+      http_options=types.HttpOptions(
+          timeout=30000,
+          retry_options=_default_retry_options(),
+      ),
   ),
 )
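A quick sketch (not ADK code) of the backoff schedule these retry options imply, assuming the conventional formula delay_n = initial_delay * exp_base**n, with no cap since max_delay is None:

initial_delay, attempts, exp_base = 2, 4, 2.0
# Waits between the 4 total attempts, under the assumed formula.
delays = [initial_delay * exp_base**n for n in range(attempts - 1)]
print(delays)  # [2.0, 4.0, 8.0]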

contributing/samples/gepa/adk_agent_test.py

Lines changed: 1 addition & 1 deletion

@@ -345,4 +345,4 @@ async def _mock_create_session(*args, **kwargs):
     )
     mock_runner_cls.assert_called_once()
     _, runner_kwargs = mock_runner_cls.call_args
-    assert runner_kwargs["agent"].sub_agents[0].model == "some-test-model"
+    assert runner_kwargs["agent"].sub_agents[0].model.model == "some-test-model"

contributing/samples/gepa/gepa_tau_bench.ipynb

Lines changed: 96 additions & 42 deletions

@@ -25,6 +25,8 @@
 "%cd ..\n",
 "!pip install gepa --quiet\n",
 "\n",
+"!pip install retry --quiet\n",
+"\n",
 "%cd tau-bench/"
 ]
 },

@@ -249,16 +251,17 @@
 "# - A GEPA adapter that bridges GEPA's optimization process with tau-bench.\n",
 "\n",
 "\n",
+"from concurrent.futures import ThreadPoolExecutor\n",
+"from datetime import datetime\n",
 "import os\n",
 "import json\n",
+"import multiprocessing\n",
 "import random\n",
+"from retry import retry\n",
 "import traceback\n",
-"import multiprocessing\n",
 "from typing import List\n",
-"from datetime import datetime\n",
-"from concurrent.futures import ThreadPoolExecutor\n",
 "\n",
-"from google.adk.examples.gepa import tau_bench_agent as tau_bench_agent_lib\n",
+"import tau_bench_agent as tau_bench_agent_lib\n",
 "from tau_bench.envs import get_env\n",
 "from tau_bench.run import display_metrics\n",
 "from tau_bench.types import EnvRunResult, RunConfig\n",

@@ -349,42 +352,48 @@
 "  if config.shuffle:\n",
 "    random.shuffle(idxs)\n",
 "\n",
-"  def _run(idx: int) -> EnvRunResult:\n",
+"  @retry(tries=3, delay=10, backoff=2)\n",
+"  def _run_with_retry(idx: int) -> EnvRunResult:\n",
 "    isolated_env = get_env(\n",
-"      config.env,\n",
-"      user_strategy=config.user_strategy,\n",
-"      user_model=config.user_model,\n",
-"      task_split=config.task_split,\n",
-"      user_provider=config.user_model_provider,\n",
-"      task_index=idx,\n",
+"        config.env,\n",
+"        user_strategy=config.user_strategy,\n",
+"        user_model=config.user_model,\n",
+"        task_split=config.task_split,\n",
+"        user_provider=config.user_model_provider,\n",
+"        task_index=idx,\n",
 "    )\n",
 "    if print_results:\n",
 "      print(f'Running task {idx}')\n",
-"    try:\n",
-"      res = agent.solve(\n",
+"    res = agent.solve(\n",
 "        env=isolated_env,\n",
 "        task_index=idx,\n",
-"      )\n",
-"      result = EnvRunResult(\n",
+"    )\n",
+"    return EnvRunResult(\n",
 "        task_id=idx,\n",
 "        reward=res.reward,\n",
 "        info=res.info,\n",
 "        traj=res.messages,\n",
 "        trial=i,\n",
-"      )\n",
+"    )\n",
+"\n",
+"  def _run(idx: int) -> EnvRunResult:\n",
+"    try:\n",
+"      result = _run_with_retry(idx)\n",
 "    except Exception as e:\n",
+"      logging.warning('Inference error: %s', str(e))\n",
 "      result = EnvRunResult(\n",
-"        task_id=idx,\n",
-"        reward=0.0,\n",
-"        info={'error': str(e), 'traceback': traceback.format_exc()},\n",
-"        traj=[],\n",
-"        trial=i,\n",
+"          task_id=idx,\n",
+"          reward=0.0,\n",
+"          info={'error': str(e), 'traceback': traceback.format_exc()},\n",
+"          traj=[],\n",
+"          trial=i,\n",
 "      )\n",
+"\n",
 "    if print_results:\n",
 "      print(\n",
-"        '✅' if result.reward == 1 else '❌',\n",
-"        f'task_id={idx}',\n",
-"        # result.info,\n",
+"          '✅' if result.reward == 1 else '❌',\n",
+"          f'task_id={idx}',\n",
+"          # result.info,\n",
 "      )\n",
 "      print('-----')\n",
 "    with lock:\n",
@@ -446,6 +455,26 @@
 "  task_info: dict\n",
 "\n",
 "\n",
+"def refine_tau_bench_trajectory(traj: list[dict[str, Any]]) -> None:\n",
+"  \"\"\"Removes unnecessary info from the trajectory, in place.\"\"\"\n",
+"  for content in traj:\n",
+"    for part in content[\"parts\"]:\n",
+"      # Drop all fields that are not populated.\n",
+"      to_drop = []\n",
+"      for key in part:\n",
+"        if not part[key]:\n",
+"          to_drop.append(key)\n",
+"      for key in to_drop:\n",
+"        del part[key]\n",
+"\n",
+"      # For function calls / responses, only keep function names, input\n",
+"      # arguments and outputs.\n",
+"      if fc := part.get(\"function_call\"):\n",
+"        part[\"function_call\"] = dict(name=fc[\"name\"], args=fc[\"args\"])\n",
+"      if fr := part.get(\"function_response\"):\n",
+"        part[\"function_response\"] = dict(name=fr[\"name\"], args=fr[\"response\"])\n",
+"\n",
+"\n",
 "class TauBenchAdapter(GEPAAdapter[\n",
 "    TauBenchDataInst,\n",
 "    TauBenchTrajectory,\n",
@@ -462,7 +491,7 @@
 "      agent_strategy='tool-calling',\n",
 "      user_strategy='llm',\n",
 "      system_instruction_name='system_instruction',\n",
-"      tool_definitions_name='tool_definitions',\n",
+"      tools_description: list[dict[str, Any]] | None = None,\n",
 "      max_concurrency=4,\n",
 "  ):\n",
 "    \"\"\"Initializes the TauBenchAdapter.\n",

@@ -476,8 +505,8 @@
 "      user_strategy: The user simulation strategy (e.g., 'llm').\n",
 "      system_instruction_name: The key in the candidate dictionary that holds\n",
 "        the system instruction.\n",
-"      tool_definitions_name: The key in the candidate dictionary that holds the\n",
-"        tool definitions.\n",
+"      tools_description: Describes each of the available tools. This is used\n",
+"        as context for the prompt proposer.\n",
 "      max_concurrency: The maximum number of tasks to run in parallel.\n",
 "    \"\"\"\n",
 "    self._agent_model = agent_model\n",

@@ -488,7 +517,7 @@
 "    self._user_strategy = user_strategy\n",
 "    self._max_concurrency = max_concurrency\n",
 "    self._system_instruction_name = system_instruction_name\n",
-"    self._tool_definitions_name = tool_definitions_name\n",
+"    self._tools_description = tools_description\n",
 "\n",
 "  def evaluate(\n",
 "      self,\n",

@@ -544,7 +573,7 @@
 "            reward=res.reward,\n",
 "            task_info=res.info))\n",
 "        result_traj = res.traj\n",
-"        # TODO - Consider refining the trajectory format.\n",
+"        refine_tau_bench_trajectory(result_traj)\n",
 "        trajectories.append(TauBenchTrajectory(result_traj=result_traj))\n",
 "        scores.append(res.reward)\n",
 "\n",

@@ -574,7 +603,13 @@
 "        data instances for reflection.\n",
 "    \"\"\"\n",
 "    system_instruction = candidate[self._system_instruction_name]\n",
-"    tool_definitions = candidate[self._tool_definitions_name]\n",
+"\n",
+"    tool_definitions = json.dumps(\n",
+"        self._tools_description,\n",
+"        indent=2,\n",
+"        default=str,\n",
+"    )\n",
+"\n",
 "    inputs = '\\n\\n'.join([\n",
 "        f'# System Instruction\\n{system_instruction}',\n",
 "        f'# Tool Definitions\\n{tool_definitions}',\n",

@@ -670,7 +705,6 @@
 "]\n",
 "\n",
 "system_instruction_name = 'system_instruction'\n",
-"tool_definitions_name = 'tool_definitions'\n",
 "\n",
 "SEED_SYSTEM_INSTRUCTION = (\n",
 "    'you are a customer support agent helping customers resolve their '\n",

@@ -679,12 +713,6 @@
 "\n",
 "seed_candidate = {\n",
 "    system_instruction_name: SEED_SYSTEM_INSTRUCTION,\n",
-"    # TODO - Consider removing tool definition from optimization space.\n",
-"    tool_definitions_name: json.dumps(\n",
-"        tool_definitions_by_domain[tau_bench_env],\n",
-"        indent=2,\n",
-"        default=str,\n",
-"    ),\n",
 "}"
 ]
 },

@@ -700,6 +728,7 @@
 "# With the configuration and adapter in place, this section creates the adapter\n",
 "# instance and calls `gepa.optimize()` to start the Automatic Prompt\n",
 "# Optimization (APO) process.\n",
+"import litellm\n",
 "\n",
 "tau_bench_adapter = TauBenchAdapter(\n",
 "    agent_model=agent_model,\n",

@@ -709,7 +738,7 @@
 "    agent_strategy='tool-calling',\n",
 "    user_strategy='llm',\n",
 "    system_instruction_name=system_instruction_name,\n",
-"    tool_definitions_name=tool_definitions_name,\n",
+"    tools_description=tool_definitions_by_domain[tau_bench_env],\n",
 "    max_concurrency=max_concurrency,\n",
 ")\n",
 "\n",

@@ -720,7 +749,13 @@
 "    task_lm=None,  # this must be None when a custom adapter is used\n",
 "    adapter=tau_bench_adapter,\n",
 "    max_metric_calls=max_metric_calls,\n",
-"    reflection_lm=f'vertex_ai/{reflection_model}',\n",
+"    reflection_lm=(\n",
+"        lambda prompt: litellm.completion_with_retries(\n",
+"            model=f'vertex_ai/{reflection_model}',\n",
+"            messages=[{\"role\": \"user\", \"content\": prompt}],\n",
+"            num_retries=4, initial_delay=1, max_delay=1,\n",
+"        ).choices[0].message.content\n",
+"    ),\n",
 "    reflection_minibatch_size=reflection_minibatch_size,\n",
 ")\n",
 "list(enumerate(gepa_results.val_aggregate_scores))"
@@ -735,7 +770,6 @@
 "outputs": [],
 "source": [
 "#@title Evaluate All Candidates\n",
-"%%time\n",
 "\n",
 "\n",
 "# This is the prompt from https://arxiv.org/pdf/2406.12045\n",

@@ -855,15 +889,35 @@
 "    )\n",
 "    system_instruction_to_eval_results[system_instruction] = tau_bench_results"
 ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+  "id": "w4Q5hMuERuO6"
+ },
+ "outputs": [],
+ "source": [
+  "print(gepa_results.best_candidate['system_instruction'])"
+ ]
+},
+{
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+  "id": "pbG7aBXLRuO6"
+ },
+ "outputs": [],
+ "source": []
 }
 ],
 "metadata": {
  "colab": {
-  "provenance": [],
   "last_runtime": {
    "build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook",
    "kind": "private"
-  }
+  },
+  "provenance": []
 },
 "kernelspec": {
  "display_name": "Python 3 (ipykernel)",
contributing/samples/gepa/tau_bench_agent.py

Lines changed: 29 additions & 0 deletions

@@ -28,6 +28,8 @@
 from typing import Any

 import adk_agent
+from google.adk.models import llm_response
+from google.adk.plugins import base_plugin
 from google.genai import types
 from tau_bench import envs
 from tau_bench import types as tau_bench_types

@@ -64,6 +66,26 @@ def _convert_tool(tool_def: dict[str, Any]) -> types.FunctionDeclaration:
   return types.FunctionDeclaration(**tool_def['function'])


+_LLM_CALL_ERROR = 'llm_call_error'
+
+
+class _TauBenchPlugin(base_plugin.BasePlugin):
+  """Catches LLM errors and emits an event with an error code for downstream usage."""
+
+  async def on_model_error_callback(
+      self,
+      *,
+      callback_context: base_plugin.CallbackContext,
+      llm_request: base_plugin.LlmRequest,
+      error: Exception,
+  ) -> llm_response.LlmResponse:
+    del callback_context, llm_request  # Unused.
+    return llm_response.LlmResponse(
+        error_code=_LLM_CALL_ERROR,
+        error_message=str(error),
+    )
+
+
 class _ADKAgent(tool_calling_agent.ToolCallingAgent):
   """ADK agent implementation for Tau Bench."""

@@ -82,6 +104,9 @@ def solve(

     Returns:
       The result of the solve.
+
+    Raises:
+      ValueError: If the LLM inference failed.
     """
     # Thought-signature is excluded from the message serialization for the
     # following reasons:

@@ -102,7 +127,11 @@
         tools=[_convert_tool(t) for t in env.tools_info],
         task_index=task_index,
         max_num_steps=max_num_steps,
+        plugins=[_TauBenchPlugin(name='error_plugin')],
     ):
+      if event.error_code == _LLM_CALL_ERROR:
+        raise ValueError(f'Error {event.error_code=}: {event.error_message=}')
+
       if not event.content:
         continue
       messages.append(event.content.model_dump(exclude=content_exclusion))
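Taken together, the commit stacks four layers of fault handling; a comment-only sketch of how they compose (an interpretation of the diff, not repo code):

# 1. HttpRetryOptions (adk_agent.py): each Gemini HTTP call is retried, up to
#    4 attempts with exponential backoff starting at 2s.
# 2. _TauBenchPlugin.on_model_error_callback: once retries are exhausted, the
#    exception becomes an LlmResponse with error_code=_LLM_CALL_ERROR instead
#    of crashing the runner.
# 3. _ADKAgent.solve: seeing that error code on the event stream, it raises
#    ValueError so the task fails cleanly.
# 4. @retry(tries=3, delay=10, backoff=2) in the notebook: the failed task is
#    re-run up to twice more before being recorded with reward=0.0.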
