|
1 | | -from typing import TYPE_CHECKING |
| 1 | +from typing import TYPE_CHECKING, Iterable |
2 | 2 |
|
3 | 3 |
|
4 | 4 | if TYPE_CHECKING: |
|
16 | 16 | logger = get_logger(__name__) |
17 | 17 |
|
18 | 18 |
|
| 19 | +def compose_callbacks(callbacks: Iterable[ConversationCallbackType]) -> ConversationCallbackType: |
| 20 | + def composed(event) -> None: |
| 21 | + for cb in callbacks: |
| 22 | + if cb: |
| 23 | + cb(event) |
| 24 | + return composed |
| 25 | + |
19 | 26 | class Conversation: |
20 | | - def __init__(self, agent: "AgentBase", on_event: ConversationCallbackType | None = None, max_iteration_per_run: int = 500): |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + agent: "AgentBase", |
| 30 | + callbacks: list[ConversationCallbackType] | None = None, |
| 31 | + max_iteration_per_run: int = 500, |
| 32 | + ): |
| 33 | + """Initialize the conversation.""" |
21 | 34 | self._visualizer = ConversationVisualizer() |
22 | | - self._on_event: ConversationCallbackType = on_event or self._visualizer.on_event |
| 35 | + # Compose multiple callbacks if a list is provided |
| 36 | + self._on_event = compose_callbacks( |
| 37 | + [self._visualizer.on_event] + (callbacks if callbacks else []) |
| 38 | + ) |
23 | 39 | self.max_iteration_per_run = max_iteration_per_run |
24 | 40 |
|
25 | 41 | self.agent = agent |
| 42 | + self._agent_initialized = False |
26 | 43 |
|
27 | 44 | # Guarding the conversation state to prevent multiple |
28 | 45 | # writers modify it at the same time |
29 | 46 | self._lock = RLock() |
30 | 47 | self.state = ConversationState() |
31 | 48 |
|
32 | | - with self._lock: |
33 | | - # will modify self.state in place |
34 | | - self.state = self.agent.init_state(self.state, on_event=self._on_event) |
35 | | - |
36 | 49 | def send_message(self, message: Message) -> None: |
37 | 50 | """Sending messages to the agent.""" |
38 | 51 | with self._lock: |
39 | | - messages = self.state.history.messages |
40 | | - messages.append(message) |
41 | | - if self._on_event: |
42 | | - self._on_event(message) |
| 52 | + if not self._agent_initialized: |
| 53 | + # Prepare initial state |
| 54 | + self.state = self.agent.init_state( |
| 55 | + self.state, |
| 56 | + initial_user_message=message, |
| 57 | + on_event=self._on_event, |
| 58 | + ) |
| 59 | + self._agent_initialized = True |
| 60 | + else: |
| 61 | + messages = self.state.history.messages |
| 62 | + messages.append(message) |
| 63 | + if self._on_event: |
| 64 | + self._on_event(message) |
43 | 65 |
|
44 | 66 | def run(self) -> None: |
45 | 67 | """Runs the conversation until the agent finishes.""" |
|
0 commit comments