|
1 | 1 | from datetime import timedelta |
2 | 2 |
|
3 | 3 | from pydantic import BaseModel |
4 | | -from restack_ai.agent import agent, import_functions, log |
| 4 | +from restack_ai.agent import NonRetryableError, agent, import_functions, log |
5 | 5 |
|
6 | 6 | with import_functions(): |
7 | 7 | from openai import pydantic_function_tool |
@@ -40,75 +40,88 @@ async def messages(self, messages_event: MessagesEvent) -> list[Message]: |
40 | 40 | ), |
41 | 41 | ] |
42 | 42 |
|
| 43 | + try: |
| 44 | + completion = await agent.step( |
| 45 | + function=llm_chat, |
| 46 | + function_input=LlmChatInput( |
| 47 | + messages=self.messages, tools=tools |
| 48 | + ), |
| 49 | + start_to_close_timeout=timedelta(seconds=120), |
| 50 | + ) |
| 51 | + except Exception as e: |
| 52 | + error_message = f"Error during llm_chat: {e}" |
| 53 | + raise NonRetryableError(error_message) from e |
| 54 | + else: |
| 55 | + log.info(f"completion: {completion}") |
| 56 | + |
| 57 | + tool_calls = completion.choices[0].message.tool_calls |
| 58 | + self.messages.append( |
| 59 | + Message( |
| 60 | + role="assistant", |
| 61 | + content=completion.choices[0].message.content or "", |
| 62 | + tool_calls=tool_calls, |
| 63 | + ) |
| 64 | + ) |
43 | 65 |
|
44 | | - completion = await agent.step( |
45 | | - function=llm_chat, |
46 | | - function_input=LlmChatInput( |
47 | | - messages=self.messages, tools=tools |
48 | | - ), |
49 | | - start_to_close_timeout=timedelta(seconds=120), |
50 | | - ) |
| 66 | + log.info(f"tool_calls: {tool_calls}") |
51 | 67 |
|
52 | | - log.info(f"completion: {completion}") |
| 68 | + if tool_calls: |
| 69 | + for tool_call in tool_calls: |
| 70 | + log.info(f"tool_call: {tool_call}") |
53 | 71 |
|
54 | | - tool_calls = completion.choices[0].message.tool_calls |
55 | | - self.messages.append( |
56 | | - Message( |
57 | | - role="assistant", |
58 | | - content=completion.choices[0].message.content or "", |
59 | | - tool_calls=tool_calls, |
60 | | - ) |
61 | | - ) |
62 | | - |
63 | | - log.info(f"tool_calls: {tool_calls}") |
64 | | - |
65 | | - if tool_calls: |
66 | | - for tool_call in tool_calls: |
67 | | - log.info(f"tool_call: {tool_call}") |
68 | | - |
69 | | - name = tool_call.function.name |
70 | | - |
71 | | - match name: |
72 | | - case lookup_sales.__name__: |
73 | | - args = LookupSalesInput.model_validate_json( |
74 | | - tool_call.function.arguments |
75 | | - ) |
76 | | - |
77 | | - log.info(f"calling {name} with args: {args}") |
78 | | - |
79 | | - result = await agent.step( |
80 | | - function=lookup_sales, |
81 | | - function_input=LookupSalesInput(category=args.category), |
82 | | - start_to_close_timeout=timedelta(seconds=120), |
83 | | - ) |
84 | | - self.messages.append( |
85 | | - Message( |
86 | | - role="tool", |
87 | | - tool_call_id=tool_call.id, |
88 | | - content=str(result), |
89 | | - ) |
90 | | - ) |
91 | | - |
92 | | - completion_with_tool_call = await agent.step( |
93 | | - function=llm_chat, |
94 | | - function_input=LlmChatInput( |
95 | | - messages=self.messages, system_content=system_content |
96 | | - ), |
97 | | - start_to_close_timeout=timedelta(seconds=120), |
98 | | - ) |
99 | | - self.messages.append( |
100 | | - Message( |
101 | | - role="assistant", |
102 | | - content=completion_with_tool_call.choices[ |
103 | | - 0 |
104 | | - ].message.content |
105 | | - or "", |
| 72 | + name = tool_call.function.name |
| 73 | + |
| 74 | + match name: |
| 75 | + case lookup_sales.__name__: |
| 76 | + args = LookupSalesInput.model_validate_json( |
| 77 | + tool_call.function.arguments |
106 | 78 | ) |
107 | | - ) |
108 | | - else: |
109 | | - pass |
110 | 79 |
|
111 | | - return self.messages |
| 80 | + log.info(f"calling {name} with args: {args}") |
| 81 | + |
| 82 | + try: |
| 83 | + result = await agent.step( |
| 84 | + function=lookup_sales, |
| 85 | + function_input=LookupSalesInput(category=args.category), |
| 86 | + start_to_close_timeout=timedelta(seconds=120), |
| 87 | + ) |
| 88 | + except Exception as e: |
| 89 | + error_message = f"Error during lookup_sales: {e}" |
| 90 | + raise NonRetryableError(error_message) from e |
| 91 | + else: |
| 92 | + self.messages.append( |
| 93 | + Message( |
| 94 | + role="tool", |
| 95 | + tool_call_id=tool_call.id, |
| 96 | + content=str(result), |
| 97 | + ) |
| 98 | + ) |
| 99 | + |
| 100 | + try: |
| 101 | + completion_with_tool_call = await agent.step( |
| 102 | + function=llm_chat, |
| 103 | + function_input=LlmChatInput( |
| 104 | + messages=self.messages |
| 105 | + ), |
| 106 | + start_to_close_timeout=timedelta(seconds=120), |
| 107 | + ) |
| 108 | + except Exception as e: |
| 109 | + error_message = f"Error during llm_chat: {e}" |
| 110 | + raise NonRetryableError(error_message) from e |
| 111 | + else: |
| 112 | + self.messages.append( |
| 113 | + Message( |
| 114 | + role="assistant", |
| 115 | + content=completion_with_tool_call.choices[ |
| 116 | + 0 |
| 117 | + ].message.content |
| 118 | + or "", |
| 119 | + ) |
| 120 | + ) |
| 121 | + else: |
| 122 | + pass |
| 123 | + |
| 124 | + return self.messages |
112 | 125 |
|
113 | 126 | @agent.event |
114 | 127 | async def end(self) -> EndEvent: |
|
0 commit comments