|
| 1 | +from typing import Callable |
| 2 | + |
| 3 | +from openhands.core.llm import LLM |
| 4 | +from openhands.core.runtime import Tool, ActionBase, ObservationBase |
| 5 | +from openhands.core.context.env_context import EnvContext |
| 6 | +from openhands.core.llm.message import Message |
| 7 | +from openhands.core.logger import get_logger |
| 8 | + |
| 9 | +logger = get_logger(__name__) |
| 10 | + |
| 11 | + |
| 12 | +class AgentBase: |
| 13 | + def __init__( |
| 14 | + self, |
| 15 | + llm: LLM, |
| 16 | + tools: list[Tool], |
| 17 | + env_context: EnvContext | None = None, |
| 18 | + ) -> None: |
| 19 | + """Initializes a new instance of the Agent class.""" |
| 20 | + self._llm = llm |
| 21 | + self._tools = tools |
| 22 | + self._name_to_tool: dict[str, Tool] = {} |
| 23 | + for tool in tools: |
| 24 | + if tool.name in self._name_to_tool: |
| 25 | + raise ValueError(f"Duplicate tool name: {tool.name}") |
| 26 | + logger.debug(f"Registering tool: {tool}") |
| 27 | + self._name_to_tool[tool.name] = tool |
| 28 | + self._env_context = env_context |
| 29 | + |
| 30 | + @property |
| 31 | + def name(self) -> str: |
| 32 | + """Returns the name of the Agent.""" |
| 33 | + return self.__class__.__name__ |
| 34 | + |
| 35 | + @property |
| 36 | + def llm(self) -> LLM: |
| 37 | + """Returns the LLM instance used by the Agent.""" |
| 38 | + return self._llm |
| 39 | + |
| 40 | + @property |
| 41 | + def tools(self) -> list[Tool]: |
| 42 | + """Returns the list of tools available to the Agent.""" |
| 43 | + return self._tools |
| 44 | + |
| 45 | + def get_tool(self, name: str) -> Tool | None: |
| 46 | + """Returns the tool with the given name, or None if not found.""" |
| 47 | + return self._name_to_tool.get(name) |
| 48 | + |
| 49 | + @property |
| 50 | + def env_context(self) -> EnvContext | None: |
| 51 | + """Returns the environment context used by the Agent.""" |
| 52 | + return self._env_context |
| 53 | + |
| 54 | + def reset(self) -> None: |
| 55 | + """Resets the Agent's internal state.""" |
| 56 | + pass |
| 57 | + |
| 58 | + def run( |
| 59 | + self, |
| 60 | + user_input: Message, |
| 61 | + on_event: Callable[[Message | ActionBase | ObservationBase], None] |
| 62 | + | None = None, |
| 63 | + ) -> None: |
| 64 | + """Runs the Agent with the given input and returns the output. |
| 65 | +
|
| 66 | + The agent will stop when it reaches a terminal state, such as |
| 67 | + completing its task by calling "finish" or messaging the user by calling "message". |
| 68 | + Implementations should invoke `on_event` (if provided) for any |
| 69 | + Messages, Actions, or Observations they produce. |
| 70 | + """ |
| 71 | + raise NotImplementedError("Subclasses must implement this method.") |
0 commit comments