Skip to content

Commit 4667080

Browse files
Fix tool schema and add relevant tests; Add bash implmentations (#4)
Co-authored-by: openhands <openhands@all-hands.dev>
1 parent 1b43bbf commit 4667080

File tree

5 files changed

+813
-13
lines changed

5 files changed

+813
-13
lines changed

openhands/core/runtime/schema.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,65 @@ def py_type(spec: dict[str, Any]) -> Any:
2424
return Any
2525

2626

27+
def _process_schema_node(node, defs):
28+
"""Recursively process a schema node to simplify and resolve $ref.
29+
30+
https://www.reddit.com/r/mcp/comments/1kjo9gt/toolinputschema_conversion_from_pydanticmodel/
31+
https://gist.github.com/leandromoreira/3de4819e4e4df9422d87f1d3e7465c16
32+
"""
33+
# Handle $ref references
34+
if "$ref" in node:
35+
ref_path = node["$ref"]
36+
if ref_path.startswith("#/$defs/"):
37+
ref_name = ref_path.split("/")[-1]
38+
if ref_name in defs:
39+
# Process the referenced definition
40+
return _process_schema_node(defs[ref_name], defs)
41+
42+
# Start with a new schema object
43+
result = {}
44+
45+
# Copy the basic properties
46+
if "type" in node:
47+
result["type"] = node["type"]
48+
49+
# Handle anyOf (often used for optional fields with None)
50+
if "anyOf" in node:
51+
non_null_types = [t for t in node["anyOf"] if t.get("type") != "null"]
52+
if non_null_types:
53+
# Process the first non-null type
54+
processed = _process_schema_node(non_null_types[0], defs)
55+
result.update(processed)
56+
57+
# Handle description
58+
if "description" in node:
59+
result["description"] = node["description"]
60+
61+
# Handle object properties recursively
62+
if node.get("type") == "object" and "properties" in node:
63+
result["type"] = "object"
64+
result["properties"] = {}
65+
66+
# Process each property
67+
for prop_name, prop_schema in node["properties"].items():
68+
result["properties"][prop_name] = _process_schema_node(prop_schema, defs)
69+
70+
# Add required fields if present
71+
if "required" in node:
72+
result["required"] = node["required"]
73+
74+
# Handle arrays
75+
if node.get("type") == "array" and "items" in node:
76+
result["type"] = "array"
77+
result["items"] = _process_schema_node(node["items"], defs)
78+
79+
# Handle enum
80+
if "enum" in node:
81+
result["enum"] = node["enum"]
82+
83+
return result
84+
85+
2786
class Schema(BaseModel):
2887
"""Base schema for input action / output observation."""
2988

@@ -32,13 +91,9 @@ class Schema(BaseModel):
3291
@classmethod
3392
def to_mcp_schema(cls) -> dict[str, Any]:
3493
"""Convert to JSON schema format compatible with MCP."""
35-
js = cls.model_json_schema()
36-
req = [n for n, f in cls.model_fields.items() if f.is_required()]
37-
return {
38-
"type": "object",
39-
"properties": js.get("properties", {}) or {},
40-
"required": req or [],
41-
}
94+
full_schema = cls.model_json_schema()
95+
# This will get rid of all "anyOf" in the schema, so it is fully compatible with MCP tool schema
96+
return _process_schema_node(full_schema, full_schema.get("$defs", {}))
4297

4398
@classmethod
4499
def from_mcp_schema(

openhands/core/runtime/tool.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
from typing import Any, Callable
1+
import re
2+
from typing import Any, Callable, TypeVar, Generic
23
from pydantic import BaseModel, Field
34
from .schema import ActionBase, ObservationBase, Schema
45

6+
ActionT = TypeVar("ActionT", bound=ActionBase)
7+
ObservationT = TypeVar("ObservationT", bound=ObservationBase)
8+
9+
10+
def to_camel_case(s: str) -> str:
11+
parts = re.split(r"[_\-\s]+", s)
12+
return "".join(word.capitalize() for word in parts if word)
13+
514

615
class ToolAnnotations(BaseModel):
716
"""Annotations to provide hints about the tool's behavior.
@@ -30,7 +39,7 @@ class ToolAnnotations(BaseModel):
3039
)
3140

3241

33-
class Tool:
42+
class Tool(Generic[ActionT, ObservationT]):
3443
"""Tool that wraps an executor function with input/output validation and schema.
3544
3645
- Normalize input/output schemas (class or dict) into both model+schema.
@@ -48,7 +57,7 @@ def __init__(
4857
description: str | None = None,
4958
annotations: ToolAnnotations | None = None,
5059
_meta: dict[str, Any] | None = None,
51-
execute_fn: Callable[[ActionBase], ObservationBase] | None = None,
60+
execute_fn: Callable[[ActionT], ObservationT] | None = None,
5261
):
5362
self.name = name
5463
self.description = description
@@ -71,7 +80,7 @@ def _set_input_schema(
7180
elif isinstance(input_schema, dict):
7281
self.input_schema = input_schema
7382
self.action_type = ActionBase.from_mcp_schema(
74-
f"{self.name}Action", input_schema
83+
f"{to_camel_case(self.name)}Action", input_schema
7584
)
7685
else:
7786
raise TypeError(
@@ -93,14 +102,18 @@ def _set_output_schema(
93102
elif isinstance(output_schema, dict):
94103
self.output_schema = output_schema
95104
self.observation_type = ObservationBase.from_mcp_schema(
96-
f"{self.name}Observation", output_schema
105+
f"{to_camel_case(self.name)}Observation", output_schema
97106
)
98107
else:
99108
raise TypeError(
100109
"output_schema must be ObservationBase subclass, dict, or None"
101110
)
102111

103-
def call(self, action: ActionBase) -> ObservationBase:
112+
def call(self, action: ActionT) -> ObservationBase:
113+
"""Validate input, execute, and coerce output.
114+
115+
We always return some ObservationBase subclass, but not always the generic ObservationT.
116+
"""
104117
if self.execute_fn is None:
105118
raise NotImplementedError(f"Tool '{self.name}' has no executor")
106119

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .definition import execute_bash_tool, ExecuteBashAction, ExecuteBashObservation
2+
3+
4+
__all__ = ["execute_bash_tool", "ExecuteBashAction", "ExecuteBashObservation"]
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""Execute bash tool implementation."""
2+
3+
from pydantic import Field
4+
5+
from openhands.core.runtime.tool import Tool, ToolAnnotations
6+
from openhands.core.runtime.schema import ActionBase, ObservationBase
7+
from openhands.core.runtime.security import SECURITY_RISK_DESC, SECURITY_RISK_LITERAL
8+
9+
10+
class ExecuteBashAction(ActionBase):
11+
"""Schema for bash command execution."""
12+
13+
command: str = Field(
14+
description="The bash command to execute. Can be empty string to view additional logs when previous exit code is `-1`. Can be `C-c` (Ctrl+C) to interrupt the currently running process. Note: You can only execute one bash command at a time. If you need to run multiple commands sequentially, you can use `&&` or `;` to chain them together."
15+
)
16+
is_input: bool = Field(
17+
default=False,
18+
description="If True, the command is an input to the running process. If False, the command is a bash command to be executed in the terminal. Default is False.",
19+
)
20+
timeout: float | None = Field(
21+
default=None,
22+
description="Optional. Sets a hard timeout in seconds for the command execution. If not provided, the command will use the default soft timeout behavior.",
23+
)
24+
security_risk: SECURITY_RISK_LITERAL = Field(description=SECURITY_RISK_DESC)
25+
26+
27+
class ExecuteBashObservation(ObservationBase):
28+
"""A ToolResult that can be rendered as a CLI output."""
29+
30+
output: str = Field(
31+
default="", description="The output from the command execution (stdout)."
32+
)
33+
exit_code: int = Field(
34+
default=0,
35+
description="The exit code of the command. -1 indicates the process hit the soft timeout and is not yet finished.",
36+
)
37+
error: str = Field(
38+
default="", description="Any error output from the command execution (stderr)."
39+
)
40+
timeout: bool = Field(
41+
default=False, description="Whether the command execution timed out."
42+
)
43+
44+
45+
TOOL_DESCRIPTION = """Execute a bash command in the terminal within a persistent shell session.
46+
47+
48+
### Command Execution
49+
* One command at a time: You can only execute one bash command at a time. If you need to run multiple commands sequentially, use `&&` or `;` to chain them together.
50+
* Persistent session: Commands execute in a persistent shell session where environment variables, virtual environments, and working directory persist between commands.
51+
* Soft timeout: Commands have a soft timeout of 10 seconds, once that's reached, you have the option to continue or interrupt the command (see section below for details)
52+
* Shell options: Do NOT use `set -e`, `set -eu`, or `set -euo pipefail` in shell scripts or commands in this environment. The runtime may not support them and can cause unusable shell sessions. If you want to run multi-line bash commands, write the commands to a file and then run it, instead.
53+
54+
### Long-running Commands
55+
* For commands that may run indefinitely, run them in the background and redirect output to a file, e.g. `python3 app.py > server.log 2>&1 &`.
56+
* For commands that may run for a long time (e.g. installation or testing commands), or commands that run for a fixed amount of time (e.g. sleep), you should set the "timeout" parameter of your function call to an appropriate value.
57+
* If a bash command returns exit code `-1`, this means the process hit the soft timeout and is not yet finished. By setting `is_input` to `true`, you can:
58+
- Send empty `command` to retrieve additional logs
59+
- Send text (set `command` to the text) to STDIN of the running process
60+
- Send control commands like `C-c` (Ctrl+C), `C-d` (Ctrl+D), or `C-z` (Ctrl+Z) to interrupt the process
61+
- If you do C-c, you can re-start the process with a longer "timeout" parameter to let it run to completion
62+
63+
### Best Practices
64+
* Directory verification: Before creating new directories or files, first verify the parent directory exists and is the correct location.
65+
* Directory management: Try to maintain working directory by using absolute paths and avoiding excessive use of `cd`.
66+
67+
### Output Handling
68+
* Output truncation: If the output exceeds a maximum length, it will be truncated before being returned.
69+
"""
70+
71+
72+
execute_bash_tool = Tool(
73+
name="execute_bash",
74+
input_schema=ExecuteBashAction,
75+
output_schema=ExecuteBashObservation,
76+
description=TOOL_DESCRIPTION,
77+
annotations=ToolAnnotations(
78+
title="execute_bash",
79+
readOnlyHint=False,
80+
destructiveHint=True,
81+
idempotentHint=False,
82+
openWorldHint=True,
83+
),
84+
)

0 commit comments

Comments
 (0)