Skip to content

Commit 855160a

Browse files
authored
support multiple callback fn; improve examples of how to add callback from script (#38)
1 parent 01f5fdc commit 855160a

File tree

8 files changed

+74
-18
lines changed

8 files changed

+74
-18
lines changed

examples/hello_world.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44

55
from openhands.core import (
66
LLM,
7+
ActionBase,
78
CodeActAgent,
89
Conversation,
10+
ConversationEventType,
911
LLMConfig,
1012
Message,
13+
ObservationBase,
1114
TextContent,
1215
Tool,
1316
get_logger,
@@ -42,7 +45,19 @@
4245

4346
# Agent
4447
agent = CodeActAgent(llm=llm, tools=tools)
45-
conversation = Conversation(agent=agent)
48+
49+
llm_messages = [] # collect raw LLM messages
50+
def conversation_callback(event: ConversationEventType):
51+
# print all the actions
52+
if isinstance(event, ActionBase):
53+
logger.info(f"Found a conversation action: {event}")
54+
elif isinstance(event, ObservationBase):
55+
logger.info(f"Found a conversation observation: {event}")
56+
elif isinstance(event, Message):
57+
logger.info(f"Found a conversation message: {str(event)[:200]}...")
58+
llm_messages.append(event.model_dump())
59+
60+
conversation = Conversation(agent=agent, callbacks=[conversation_callback])
4661

4762
conversation.send_message(
4863
message=Message(
@@ -51,3 +66,8 @@
5166
)
5267
)
5368
conversation.run()
69+
70+
print("="*100)
71+
print("Conversation finished. Got the following LLM messages:")
72+
for i, message in enumerate(llm_messages):
73+
print(f"Message {i}: {str(message)[:200]}")

openhands/core/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .agent import AgentBase, CodeActAgent
44
from .config import LLMConfig, MCPConfig
5-
from .conversation import Conversation
5+
from .conversation import Conversation, ConversationCallbackType, ConversationEventType
66
from .llm import LLM, ImageContent, Message, TextContent
77
from .logger import get_logger
88
from .tool import ActionBase, ObservationBase, Tool
@@ -27,5 +27,7 @@
2727
"MCPConfig",
2828
"get_logger",
2929
"Conversation",
30+
"ConversationCallbackType",
31+
"ConversationEventType",
3032
"__version__",
3133
]

openhands/core/agent/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from openhands.core.context.env_context import EnvContext
55
from openhands.core.conversation import ConversationCallbackType, ConversationState
6-
from openhands.core.llm import LLM
6+
from openhands.core.llm import LLM, Message
77
from openhands.core.logger import get_logger
88
from openhands.core.tool import Tool
99

@@ -60,6 +60,7 @@ def env_context(self) -> EnvContext | None:
6060
def init_state(
6161
self,
6262
state: ConversationState,
63+
initial_user_message: Message | None = None,
6364
on_event: ConversationCallbackType | None = None,
6465
) -> ConversationState:
6566
"""Initialize the empty conversation state to prepare the agent for user messages.

openhands/core/agent/codeact_agent/codeact_agent.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,25 @@ def __init__(
7777
def init_state(
7878
self,
7979
state: ConversationState,
80+
initial_user_message: Message | None = None,
8081
on_event: ConversationCallbackType | None = None,
8182
) -> ConversationState:
8283
# TODO(openhands): we should add test to test this init_state will actually modify state in-place
8384
messages = state.history.messages
8485
if len(messages) == 0:
86+
# Prepare system message
8587
sys_msg = Message(role="system", content=[self.system_message])
8688
messages.append(sys_msg)
8789
if on_event:
8890
on_event(sys_msg)
89-
content = state.history.messages[-1].content
91+
if initial_user_message is None:
92+
raise ValueError("initial_user_message must be provided in init_state for CodeActAgent")
93+
94+
# Prepare user message
95+
content = initial_user_message.content
96+
# TODO: think about this - we might want to handle this outside Agent but inside Conversation (e.g., in send_messages)
97+
# downside of handling them inside Conversation would be: conversation don't have access
98+
# to *any* action execution runtime information
9099
if self.env_context:
91100
initial_env_context: list[TextContent] = self.env_context.render(self.prompt_manager)
92101
content += initial_env_context
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .conversation import Conversation
22
from .state import ConversationState
3-
from .types import ConversationCallbackType
3+
from .types import ConversationCallbackType, ConversationEventType
44
from .visualizer import ConversationVisualizer
55

66

7-
__all__ = ["Conversation", "ConversationState", "ConversationCallbackType", "ConversationVisualizer"]
7+
__all__ = ["Conversation", "ConversationState", "ConversationCallbackType", "ConversationEventType", "ConversationVisualizer"]

openhands/core/conversation/conversation.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING
1+
from typing import TYPE_CHECKING, Iterable
22

33

44
if TYPE_CHECKING:
@@ -16,30 +16,52 @@
1616
logger = get_logger(__name__)
1717

1818

19+
def compose_callbacks(callbacks: Iterable[ConversationCallbackType]) -> ConversationCallbackType:
20+
def composed(event) -> None:
21+
for cb in callbacks:
22+
if cb:
23+
cb(event)
24+
return composed
25+
1926
class Conversation:
20-
def __init__(self, agent: "AgentBase", on_event: ConversationCallbackType | None = None, max_iteration_per_run: int = 500):
27+
def __init__(
28+
self,
29+
agent: "AgentBase",
30+
callbacks: list[ConversationCallbackType] | None = None,
31+
max_iteration_per_run: int = 500,
32+
):
33+
"""Initialize the conversation."""
2134
self._visualizer = ConversationVisualizer()
22-
self._on_event: ConversationCallbackType = on_event or self._visualizer.on_event
35+
# Compose multiple callbacks if a list is provided
36+
self._on_event = compose_callbacks(
37+
[self._visualizer.on_event] + (callbacks if callbacks else [])
38+
)
2339
self.max_iteration_per_run = max_iteration_per_run
2440

2541
self.agent = agent
42+
self._agent_initialized = False
2643

2744
# Guarding the conversation state to prevent multiple
2845
# writers modify it at the same time
2946
self._lock = RLock()
3047
self.state = ConversationState()
3148

32-
with self._lock:
33-
# will modify self.state in place
34-
self.state = self.agent.init_state(self.state, on_event=self._on_event)
35-
3649
def send_message(self, message: Message) -> None:
3750
"""Sending messages to the agent."""
3851
with self._lock:
39-
messages = self.state.history.messages
40-
messages.append(message)
41-
if self._on_event:
42-
self._on_event(message)
52+
if not self._agent_initialized:
53+
# Prepare initial state
54+
self.state = self.agent.init_state(
55+
self.state,
56+
initial_user_message=message,
57+
on_event=self._on_event,
58+
)
59+
self._agent_initialized = True
60+
else:
61+
messages = self.state.history.messages
62+
messages.append(message)
63+
if self._on_event:
64+
self._on_event(message)
4365

4466
def run(self) -> None:
4567
"""Runs the conversation until the agent finishes."""

openhands/core/conversation/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
from openhands.core.tool import ActionBase, ObservationBase
55

66

7-
ConversationCallbackType = Callable[[Message | ActionBase | ObservationBase], None]
7+
ConversationEventType = Message | ActionBase | ObservationBase
8+
ConversationCallbackType = Callable[[ConversationEventType], None]

openhands/core/llm/llm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,7 @@ def __str__(self) -> str:
676676
def __repr__(self) -> str:
677677
return str(self)
678678

679+
# TODO: we should ideally format this into a `to_litellm_message` for `Message` class`
679680
def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
680681
if isinstance(messages, Message):
681682
messages = [messages]

0 commit comments

Comments
 (0)