1111
1212from openhands .core .context import EnvContext , render_system_message
1313from openhands .core .conversation import ConversationCallbackType , ConversationState
14- from openhands .core .event import ActionEvent , AgentErrorEvent , LLMConvertibleEvent , MessageEvent , ObservationEvent , SystemPromptEvent
14+ from openhands .core .event import (
15+ ActionEvent ,
16+ AgentErrorEvent ,
17+ LLMConvertibleEvent ,
18+ MessageEvent ,
19+ ObservationEvent ,
20+ SystemPromptEvent ,
21+ )
1522from openhands .core .llm import LLM , Message , TextContent , get_llm_metadata
1623from openhands .core .logger import get_logger
17- from openhands .core .tool import BUILT_IN_TOOLS , ActionBase , FinishTool , ObservationBase , Tool
24+ from openhands .core .tool import (
25+ BUILT_IN_TOOLS ,
26+ ActionBase ,
27+ FinishTool ,
28+ ObservationBase ,
29+ Tool ,
30+ )
1831
1932from ..base import AgentBase
2033
@@ -32,10 +45,18 @@ def __init__(
3245 cli_mode : bool = True ,
3346 ) -> None :
3447 for tool in BUILT_IN_TOOLS :
35- assert tool not in tools , f"{ tool } is automatically included and should not be provided."
48+ assert tool not in tools , (
49+ f"{ tool } is automatically included and should not be provided."
50+ )
3651 super ().__init__ (llm = llm , tools = tools + BUILT_IN_TOOLS , env_context = env_context )
3752
38- self .system_message : TextContent = TextContent (text = render_system_message (prompt_dir = self .prompt_dir , system_prompt_filename = system_prompt_filename , cli_mode = cli_mode ))
53+ self .system_message : TextContent = TextContent (
54+ text = render_system_message (
55+ prompt_dir = self .prompt_dir ,
56+ system_prompt_filename = system_prompt_filename ,
57+ cli_mode = cli_mode ,
58+ )
59+ )
3960
4061 self .max_iterations : int = 10
4162
@@ -44,11 +65,16 @@ def init_state(
4465 state : ConversationState ,
4566 on_event : ConversationCallbackType ,
4667 ) -> None :
47- # TODO(openhands): we should add test to test this init_state will actually modify state in-place
68+ # TODO(openhands): we should add test to test this init_state will actually
69+ # modify state in-place
4870 messages = [e .to_llm_message () for e in state .events ]
4971 if len (messages ) == 0 :
5072 # Prepare system message
51- event = SystemPromptEvent (source = "agent" , system_prompt = self .system_message , tools = [t .to_openai_tool () for t in self .tools .values ()])
73+ event = SystemPromptEvent (
74+ source = "agent" ,
75+ system_prompt = self .system_message ,
76+ tools = [t .to_openai_tool () for t in self .tools .values ()],
77+ )
5278 on_event (event )
5379
5480 def step (
@@ -57,13 +83,22 @@ def step(
5783 on_event : ConversationCallbackType ,
5884 ) -> None :
5985 # Get LLM Response (Action)
60- llm_convertible_events = cast (list [LLMConvertibleEvent ], [e for e in state .events if isinstance (e , LLMConvertibleEvent )])
61- _messages = self .llm .format_messages_for_llm (LLMConvertibleEvent .events_to_messages (llm_convertible_events ))
86+ llm_convertible_events = cast (
87+ list [LLMConvertibleEvent ],
88+ [e for e in state .events if isinstance (e , LLMConvertibleEvent )],
89+ )
90+ _messages = self .llm .format_messages_for_llm (
91+ LLMConvertibleEvent .events_to_messages (llm_convertible_events )
92+ )
6293 logger .debug (f"Sending messages to LLM: { json .dumps (_messages , indent = 2 )} " )
6394 response : ModelResponse = self .llm .completion (
6495 messages = _messages ,
6596 tools = [tool .to_openai_tool () for tool in self .tools .values ()],
66- extra_body = {"metadata" : get_llm_metadata (model_name = self .llm .config .model , agent_name = self .name )},
97+ extra_body = {
98+ "metadata" : get_llm_metadata (
99+ model_name = self .llm .config .model , agent_name = self .name
100+ )
101+ },
67102 )
68103 assert len (response .choices ) == 1 and isinstance (response .choices [0 ], Choices )
69104 llm_message : LiteLLMMessage = response .choices [0 ].message # type: ignore
@@ -72,12 +107,24 @@ def step(
72107 if message .tool_calls and len (message .tool_calls ) > 0 :
73108 tool_call : ChatCompletionMessageToolCall
74109 if any (tc .type != "function" for tc in message .tool_calls ):
75- logger .warning ("LLM returned tool calls but some are not of type 'function' - ignoring those" )
110+ logger .warning (
111+ "LLM returned tool calls but some are not of type 'function' - "
112+ "ignoring those"
113+ )
76114
77- tool_calls = [tool_call for tool_call in message .tool_calls if tool_call .type == "function" ]
78- assert len (tool_calls ) > 0 , "LLM returned tool calls but none are of type 'function'"
115+ tool_calls = [
116+ tool_call
117+ for tool_call in message .tool_calls
118+ if tool_call .type == "function"
119+ ]
120+ assert len (tool_calls ) > 0 , (
121+ "LLM returned tool calls but none are of type 'function'"
122+ )
79123 if not all (isinstance (c , TextContent ) for c in message .content ):
80- logger .warning ("LLM returned tool calls but message content is not all TextContent - ignoring non-text content" )
124+ logger .warning (
125+ "LLM returned tool calls but message content is not all "
126+ "TextContent - ignoring non-text content"
127+ )
81128
82129 # Generate unique batch ID for this LLM response
83130 thought_content = [c for c in message .content if isinstance (c , TextContent )]
@@ -89,7 +136,9 @@ def step(
89136 tool_call ,
90137 llm_response_id = response .id ,
91138 on_event = on_event ,
92- thought = thought_content if i == 0 else [], # Only first gets thought
139+ thought = thought_content
140+ if i == 0
141+ else [], # Only first gets thought
93142 )
94143 if action_event is None :
95144 continue
@@ -130,34 +179,62 @@ def _get_action_events(
130179
131180 # Validate arguments
132181 try :
133- action : ActionBase = tool .action_type .model_validate (json .loads (tool_call .function .arguments ))
182+ action : ActionBase = tool .action_type .model_validate (
183+ json .loads (tool_call .function .arguments )
184+ )
134185 except (json .JSONDecodeError , ValidationError ) as e :
135- err = f"Error validating args { tool_call .function .arguments } for tool '{ tool .name } ': { e } "
186+ err = (
187+ f"Error validating args { tool_call .function .arguments } for tool "
188+ f"'{ tool .name } ': { e } "
189+ )
136190 event = AgentErrorEvent (error = err )
137191 on_event (event )
138192 return
139193
140194 # Create one ActionEvent per action
141- action_event = ActionEvent (action = action , thought = thought , tool_name = tool .name , tool_call_id = tool_call .id , tool_call = tool_call , llm_response_id = llm_response_id )
195+ action_event = ActionEvent (
196+ action = action ,
197+ thought = thought ,
198+ tool_name = tool .name ,
199+ tool_call_id = tool_call .id ,
200+ tool_call = tool_call ,
201+ llm_response_id = llm_response_id ,
202+ )
142203 on_event (action_event )
143204 return action_event
144205
145- def _execute_action_events (self , state : ConversationState , action_event : ActionEvent , on_event : ConversationCallbackType ):
206+ def _execute_action_events (
207+ self ,
208+ state : ConversationState ,
209+ action_event : ActionEvent ,
210+ on_event : ConversationCallbackType ,
211+ ):
146212 """Execute action events and update the conversation state.
147213
148- It will call the tool's executor and update the state & call callback fn with the observation.
214+ It will call the tool's executor and update the state & call callback fn
215+ with the observation.
149216 """
150217 tool = self .tools .get (action_event .tool_name , None )
151218 if tool is None :
152- raise RuntimeError (f"Tool '{ action_event .tool_name } ' not found. This should not happen as it was checked earlier." )
219+ raise RuntimeError (
220+ f"Tool '{ action_event .tool_name } ' not found. This should not happen "
221+ "as it was checked earlier."
222+ )
153223
154224 # Execute actions!
155225 if tool .executor is None :
156226 raise RuntimeError (f"Tool '{ tool .name } ' has no executor" )
157227 observation : ObservationBase = tool .executor (action_event .action )
158- assert isinstance (observation , ObservationBase ), f"Tool '{ tool .name } ' executor must return an ObservationBase"
228+ assert isinstance (observation , ObservationBase ), (
229+ f"Tool '{ tool .name } ' executor must return an ObservationBase"
230+ )
159231
160- obs_event = ObservationEvent (observation = observation , action_id = action_event .id , tool_name = tool .name , tool_call_id = action_event .tool_call .id )
232+ obs_event = ObservationEvent (
233+ observation = observation ,
234+ action_id = action_event .id ,
235+ tool_name = tool .name ,
236+ tool_call_id = action_event .tool_call .id ,
237+ )
161238 on_event (obs_event )
162239
163240 # Set conversation state
0 commit comments