|
53 | 53 | ToolCallItemTypes, |
54 | 54 | TResponseInputItem, |
55 | 55 | ) |
56 | | -from .lifecycle import RunHooks |
| 56 | +from .lifecycle import AgentHooksBase, RunHooks, RunHooksBase |
57 | 57 | from .logger import logger |
58 | 58 | from .memory import Session, SessionInputCallback |
59 | 59 | from .model_settings import ModelSettings |
@@ -461,13 +461,11 @@ async def run( |
461 | 461 | ) -> RunResult: |
462 | 462 | context = kwargs.get("context") |
463 | 463 | max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) |
464 | | - hooks = kwargs.get("hooks") |
| 464 | + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) |
465 | 465 | run_config = kwargs.get("run_config") |
466 | 466 | previous_response_id = kwargs.get("previous_response_id") |
467 | 467 | conversation_id = kwargs.get("conversation_id") |
468 | 468 | session = kwargs.get("session") |
469 | | - if hooks is None: |
470 | | - hooks = RunHooks[Any]() |
471 | 469 | if run_config is None: |
472 | 470 | run_config = RunConfig() |
473 | 471 |
|
@@ -668,14 +666,12 @@ def run_streamed( |
668 | 666 | ) -> RunResultStreaming: |
669 | 667 | context = kwargs.get("context") |
670 | 668 | max_turns = kwargs.get("max_turns", DEFAULT_MAX_TURNS) |
671 | | - hooks = kwargs.get("hooks") |
| 669 | + hooks = cast(RunHooks[TContext], self._validate_run_hooks(kwargs.get("hooks"))) |
672 | 670 | run_config = kwargs.get("run_config") |
673 | 671 | previous_response_id = kwargs.get("previous_response_id") |
674 | 672 | conversation_id = kwargs.get("conversation_id") |
675 | 673 | session = kwargs.get("session") |
676 | 674 |
|
677 | | - if hooks is None: |
678 | | - hooks = RunHooks[Any]() |
679 | 675 | if run_config is None: |
680 | 676 | run_config = RunConfig() |
681 | 677 |
|
@@ -732,6 +728,23 @@ def run_streamed( |
732 | 728 | ) |
733 | 729 | return streamed_result |
734 | 730 |
|
| 731 | + @staticmethod |
| 732 | + def _validate_run_hooks( |
| 733 | + hooks: RunHooksBase[Any, Agent[Any]] | AgentHooksBase[Any, Agent[Any]] | Any | None, |
| 734 | + ) -> RunHooks[Any]: |
| 735 | + if hooks is None: |
| 736 | + return RunHooks[Any]() |
| 737 | + input_hook_type = type(hooks).__name__ |
| 738 | + if isinstance(hooks, AgentHooksBase): |
| 739 | + raise TypeError( |
| 740 | + "Run hooks must be instances of RunHooks. " |
| 741 | + f"Received agent-scoped hooks ({input_hook_type}). " |
| 742 | + "Attach AgentHooks to an Agent via Agent(..., hooks=...)." |
| 743 | + ) |
| 744 | + if not isinstance(hooks, RunHooksBase): |
| 745 | + raise TypeError(f"Run hooks must be instances of RunHooks. Received {input_hook_type}.") |
| 746 | + return hooks |
| 747 | + |
735 | 748 | @classmethod |
736 | 749 | async def _maybe_filter_model_input( |
737 | 750 | cls, |
|
0 commit comments