|
25 | 25 | "%cd ..\n", |
26 | 26 | "!pip install gepa --quiet\n", |
27 | 27 | "\n", |
| 28 | + "!pip install retry --quiet\n", |
| 29 | + "\n", |
28 | 30 | "%cd tau-bench/" |
29 | 31 | ] |
30 | 32 | }, |
|
249 | 251 | "# - A GEPA adapter that bridges GEPA's optimization process with tau-bench.\n", |
250 | 252 | "\n", |
251 | 253 | "\n", |
| 254 | + "from concurrent.futures import ThreadPoolExecutor\n", |
| 255 | + "from datetime import datetime\n", |
252 | 256 | "import os\n", |
253 | 257 | "import json\n", |
| 258 | + "import multiprocessing\n", |
254 | 259 | "import random\n", |
| 260 | + "from retry import retry\n", |
255 | 261 | "import traceback\n", |
256 | | - "import multiprocessing\n", |
257 | 262 | "from typing import List\n", |
258 | | - "from datetime import datetime\n", |
259 | | - "from concurrent.futures import ThreadPoolExecutor\n", |
260 | 263 | "\n", |
261 | | - "from google.adk.examples.gepa import tau_bench_agent as tau_bench_agent_lib\n", |
| 264 | + "import tau_bench_agent as tau_bench_agent_lib\n", |
262 | 265 | "from tau_bench.envs import get_env\n", |
263 | 266 | "from tau_bench.run import display_metrics\n", |
264 | 267 | "from tau_bench.types import EnvRunResult, RunConfig\n", |
|
349 | 352 | " if config.shuffle:\n", |
350 | 353 | " random.shuffle(idxs)\n", |
351 | 354 | "\n", |
352 | | - " def _run(idx: int) -> EnvRunResult:\n", |
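| | + "    # Each task runs in an isolated env; transient env/LLM failures are\n", |
| | + "    # retried up to 3 times, starting at a 10s delay and doubling each time.\n", |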
| 355 | + " @retry(tries=3, delay=10, backoff=2)\n", |
| 356 | + " def _run_with_retry(idx: int) -> EnvRunResult:\n", |
353 | 357 | " isolated_env = get_env(\n", |
354 | | - " config.env,\n", |
355 | | - " user_strategy=config.user_strategy,\n", |
356 | | - " user_model=config.user_model,\n", |
357 | | - " task_split=config.task_split,\n", |
358 | | - " user_provider=config.user_model_provider,\n", |
359 | | - " task_index=idx,\n", |
| 358 | + " config.env,\n", |
| 359 | + " user_strategy=config.user_strategy,\n", |
| 360 | + " user_model=config.user_model,\n", |
| 361 | + " task_split=config.task_split,\n", |
| 362 | + " user_provider=config.user_model_provider,\n", |
| 363 | + " task_index=idx,\n", |
360 | 364 | " )\n", |
361 | 365 | " if print_results:\n", |
362 | 366 | " print(f'Running task {idx}')\n", |
363 | | - " try:\n", |
364 | | - " res = agent.solve(\n", |
| 367 | + " res = agent.solve(\n", |
365 | 368 | " env=isolated_env,\n", |
366 | 369 | " task_index=idx,\n", |
367 | | - " )\n", |
368 | | - " result = EnvRunResult(\n", |
| 370 | + " )\n", |
| 371 | + " return EnvRunResult(\n", |
369 | 372 | " task_id=idx,\n", |
370 | 373 | " reward=res.reward,\n", |
371 | 374 | " info=res.info,\n", |
372 | 375 | " traj=res.messages,\n", |
373 | 376 | " trial=i,\n", |
374 | | - " )\n", |
| 377 | + " )\n", |
| 378 | + "\n", |
| 379 | + " def _run(idx: int) -> EnvRunResult:\n", |
| 380 | + " try:\n", |
| 381 | + " result = _run_with_retry(idx)\n", |
375 | 382 | " except Exception as e:\n", |
| 383 | + " logging.warning('Inference error: %s', str(e))\n", |
376 | 384 | " result = EnvRunResult(\n", |
377 | | - " task_id=idx,\n", |
378 | | - " reward=0.0,\n", |
379 | | - " info={'error': str(e), 'traceback': traceback.format_exc()},\n", |
380 | | - " traj=[],\n", |
381 | | - " trial=i,\n", |
| 385 | + " task_id=idx,\n", |
| 386 | + " reward=0.0,\n", |
| 387 | + " info={'error': str(e), 'traceback': traceback.format_exc()},\n", |
| 388 | + " traj=[],\n", |
| 389 | + " trial=i,\n", |
382 | 390 | " )\n", |
| 391 | + "\n", |
383 | 392 | " if print_results:\n", |
384 | 393 | " print(\n", |
385 | | - " '✅' if result.reward == 1 else '❌',\n", |
386 | | - " f'task_id={idx}',\n", |
387 | | - " # result.info,\n", |
| 394 | + " '✅' if result.reward == 1 else '❌',\n", |
| 395 | + " f'task_id={idx}',\n", |
| 396 | + " # result.info,\n", |
388 | 397 | " )\n", |
389 | 398 | " print('-----')\n", |
390 | 399 | " with lock:\n", |
|
446 | 455 | " task_info: dict\n", |
447 | 456 | "\n", |
448 | 457 | "\n", |
| 458 | + "def refine_tau_bench_trajectory(traj: list[dict[str, Any]]) -> None:\n", |
| 459 | + " \"\"\"Removes unnecessary info from the trajectory, in place.\"\"\"\n", |
| 460 | + " for content in traj:\n", |
| 461 | + " for part in content[\"parts\"]:\n", |
| 462 | + " # Drop all fields that are not populated.\n", |
| 463 | + " to_drop = []\n", |
| 464 | + " for key in part:\n", |
| 465 | + " if not part[key]:\n", |
| 466 | + " to_drop.append(key)\n", |
| 467 | + " for key in to_drop:\n", |
| 468 | + " del part[key]\n", |
| 469 | + "\n", |
| 470 | + " # For function calls / responses only keep function names, input arguments\n", |
| 471 | + " # and outputs.\n", |
| 472 | + " if fc := part.get(\"function_call\"):\n", |
| 473 | + " part[\"function_call\"] = dict(name=fc[\"name\"], args=fc[\"args\"])\n", |
| 474 | + " if fr := part.get(\"function_response\"):\n", |
| 475 | + " part[\"function_response\"] = dict(name=fr[\"name\"], args=fr[\"response\"])\n", |
| 476 | + "\n", |
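| | + "# A hypothetical before/after for one part, showing the ADK-style content\n", |
| | + "# dicts this helper assumes:\n", |
| | + "#   {'text': '', 'function_call': {'id': 'c1', 'name': 'get_user', 'args': {'u': 1}}}\n", |
| | + "# becomes {'function_call': {'name': 'get_user', 'args': {'u': 1}}}\n", |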
| 477 | + "\n", |
449 | 478 | "class TauBenchAdapter(GEPAAdapter[\n", |
450 | 479 | " TauBenchDataInst,\n", |
451 | 480 | " TauBenchTrajectory,\n", |
|
462 | 491 | " agent_strategy='tool-calling',\n", |
463 | 492 | " user_strategy='llm',\n", |
464 | 493 | " system_instruction_name='system_instruction',\n", |
465 | | - " tool_definitions_name='tool_definitions',\n", |
| 494 | + " tools_description: list[dict[str, Any]] | None = None,\n", |
466 | 495 | " max_concurrency=4,\n", |
467 | 496 | " ):\n", |
468 | 497 | " \"\"\"Initializes the TauBenchAdapter.\n", |
|
476 | 505 | " user_strategy: The user simulation strategy (e.g., 'llm').\n", |
477 | 506 | " system_instruction_name: The key in the candidate dictionary that holds\n", |
478 | 507 | " the system instruction.\n", |
479 | | - " tool_definitions_name: The key in the candidate dictionary that holds the\n", |
480 | | - " tool definitions.\n", |
| 508 | + " tools_description: Describes each of the availble tools. This is used as context\n", |
| 509 | + " for the prompt proposer.\n", |
481 | 510 | " max_concurrency: The maximum number of tasks to run in parallel.\n", |
482 | 511 | " \"\"\"\n", |
483 | 512 | " self._agent_model = agent_model\n", |
|
488 | 517 | " self._user_strategy = user_strategy\n", |
489 | 518 | " self._max_concurrency = max_concurrency\n", |
490 | 519 | " self._system_instruction_name = system_instruction_name\n", |
491 | | - " self._tool_definitions_name = tool_definitions_name\n", |
| 520 | + " self._tools_description = tools_description\n", |
492 | 521 | "\n", |
493 | 522 | " def evaluate(\n", |
494 | 523 | " self,\n", |
|
544 | 573 | " reward=res.reward,\n", |
545 | 574 | " task_info=res.info))\n", |
546 | 575 | " result_traj = res.traj\n", |
547 | | - " # TODO - Consider refining the trajectory format.\n", |
| 576 | + " refine_tau_bench_trajectory(result_traj)\n", |
548 | 577 | " trajectories.append(TauBenchTrajectory(result_traj=result_traj))\n", |
549 | 578 | " scores.append(res.reward)\n", |
550 | 579 | "\n", |
|
574 | 603 | " data instances for reflection.\n", |
575 | 604 | " \"\"\"\n", |
576 | 605 | " system_instruction = candidate[self._system_instruction_name]\n", |
577 | | - " tool_definitions = candidate[self._tool_definitions_name]\n", |
| 606 | + "\n", |
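| | + "    # The fixed tool schema is serialized here purely as context for the\n", |
| | + "    # reflection model; it is no longer part of the optimization space.\n", |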
| 607 | + " tool_definitions = json.dumps(\n", |
| 608 | + " self._tools_description,\n", |
| 609 | + " indent=2,\n", |
| 610 | + " default=str,\n", |
| 611 | + " )\n", |
| 612 | + "\n", |
578 | 613 | " inputs = '\\n\\n'.join([\n", |
579 | 614 | " f'# System Instruction\\n{system_instruction}',\n", |
580 | 615 | " f'# Tool Definitions\\n{tool_definitions}',\n", |
|
670 | 705 | "]\n", |
671 | 706 | "\n", |
672 | 707 | "system_instruction_name = 'system_instruction'\n", |
673 | | - "tool_definitions_name = 'tool_definitions'\n", |
674 | 708 | "\n", |
675 | 709 | "SEED_SYSTEM_INSTRUCTION = (\n", |
676 | 710 | " 'you are a customer support agent helping customers resolve their '\n", |
|
679 | 713 | "\n", |
680 | 714 | "seed_candidate = {\n", |
681 | 715 | " system_instruction_name: SEED_SYSTEM_INSTRUCTION,\n", |
682 | | - " # TODO - Consider removing tool definition from optimization space.\n", |
683 | | - " tool_definitions_name: json.dumps(\n", |
684 | | - " tool_definitions_by_domain[tau_bench_env],\n", |
685 | | - " indent=2,\n", |
686 | | - " default=str,\n", |
687 | | - " ),\n", |
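| | + "    # Only the system instruction is optimized now; the tool definitions\n", |
| | + "    # are fixed and handed to the adapter via `tools_description`.\n", |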
688 | 716 | "}" |
689 | 717 | ] |
690 | 718 | }, |
|
700 | 728 | "# With the configuration and adapter in place, this section creates the adapter\n", |
701 | 729 | "# instance and calls `gepa.optimize()` to start the Automatic Prompt\n", |
702 | 730 | "# Optimization (APO) process.\n", |
| 731 | + "import litellm\n", |
703 | 732 | "\n", |
704 | 733 | "tau_bench_adapter = TauBenchAdapter(\n", |
705 | 734 | " agent_model=agent_model,\n", |
|
709 | 738 | " agent_strategy='tool-calling',\n", |
710 | 739 | " user_strategy='llm',\n", |
711 | 740 | " system_instruction_name=system_instruction_name,\n", |
712 | | - " tool_definitions_name=tool_definitions_name,\n", |
| 741 | + " tools_description=tool_definitions_by_domain[tau_bench_env],\n", |
713 | 742 | " max_concurrency=max_concurrency,\n", |
714 | 743 | ")\n", |
715 | 744 | "\n", |
|
720 | 749 | " task_lm=None, # this must be None when a custom adapter is used\n", |
721 | 750 | " adapter=tau_bench_adapter,\n", |
722 | 751 | " max_metric_calls=max_metric_calls,\n", |
723 | | - " reflection_lm=f'vertex_ai/{reflection_model}',\n", |
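| | + "    # reflection_lm may also be a callable (prompt -> str); wrapping the\n", |
| | + "    # litellm call in its retry helper guards against transient API errors.\n", |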
| 752 | + " reflection_lm = (\n", |
| 753 | + " lambda prompt: litellm.completion_with_retries(\n", |
| 754 | + " model=f'vertex_ai/{reflection_model}',\n", |
| 755 | + " messages=[{\"role\": \"user\", \"content\": prompt}],\n", |
| 756 | + " num_retries=4, initial_delay=1, max_delay=1,\n", |
| 757 | + " ).choices[0].message.content\n", |
| 758 | + " ),\n", |
724 | 759 | " reflection_minibatch_size=reflection_minibatch_size,\n", |
725 | 760 | ")\n", |
726 | 761 | "list(enumerate(gepa_results.val_aggregate_scores))" |
|
735 | 770 | "outputs": [], |
736 | 771 | "source": [ |
737 | 772 | "#@title Evaluate All Candidates\n", |
738 | | - "%%time\n", |
739 | 773 | "\n", |
740 | 774 | "\n", |
741 | 775 | "# This is the prompt from https://arxiv.org/pdf/2406.12045\n", |
|
855 | 889 | " )\n", |
856 | 890 | " system_instruction_to_eval_results[system_instruction] = tau_bench_results" |
857 | 891 | ] |
| 892 | + }, |
| 893 | + { |
| 894 | + "cell_type": "code", |
| 895 | + "execution_count": null, |
| 896 | + "metadata": { |
| 897 | + "id": "w4Q5hMuERuO6" |
| 898 | + }, |
| 899 | + "outputs": [], |
| 900 | + "source": [ |
| 901 | + "print(gepa_results.best_candidate['system_instruction'])" |
| 902 | + ] |
858 | 912 | } |
859 | 913 | ], |
860 | 914 | "metadata": { |
861 | 915 | "colab": { |
862 | | - "provenance": [], |
863 | 916 | "last_runtime": { |
864 | 917 | "build_target": "//learning/language/tunelab/tunekit/colab:colab_notebook", |
865 | 918 | "kind": "private" |
866 | | - } |
| 919 | + }, |
| 920 | + "provenance": [] |
867 | 921 | }, |
868 | 922 | "kernelspec": { |
869 | 923 | "display_name": "Python 3 (ipykernel)", |
|