|
| 1 | +import asyncio |
1 | 2 | import logging |
2 | 3 | import traceback |
3 | 4 | from collections.abc import Sequence |
@@ -56,33 +57,50 @@ def get_tool_description(self) -> Tool: |
56 | 57 | "minimum": 0, |
57 | 58 | }, |
58 | 59 | }, |
59 | | - "required": ["command"], |
| 60 | + "required": ["command", "directory"], |
60 | 61 | }, |
61 | 62 | ) |
62 | 63 |
|
63 | 64 | async def run_tool(self, arguments: dict) -> Sequence[TextContent]: |
64 | 65 | """Execute the shell command with the given arguments""" |
65 | 66 | command = arguments.get("command", []) |
66 | 67 | stdin = arguments.get("stdin") |
67 | | - directory = arguments.get("directory") |
| 68 | + directory = arguments.get("directory", "/tmp") # default to /tmp for safety |
68 | 69 | timeout = arguments.get("timeout") |
69 | 70 |
|
70 | 71 | if not command: |
71 | 72 | raise ValueError("No command provided") |
72 | 73 |
|
73 | | - result = await self.executor.execute(command, stdin, directory, timeout) |
| 74 | + if not isinstance(command, list): |
| 75 | + raise ValueError("'command' must be an array") |
74 | 76 |
|
75 | | - # Raise error if command execution failed |
76 | | - if result.get("error"): |
77 | | - raise RuntimeError(result["error"]) |
| 77 | + # Make sure directory exists |
| 78 | + if not directory: |
| 79 | + raise ValueError("Directory is required") |
78 | 80 |
|
79 | | - # Convert executor result to TextContent sequence |
80 | 81 | content: list[TextContent] = [] |
81 | | - |
82 | | - if result.get("stdout"): |
83 | | - content.append(TextContent(type="text", text=result["stdout"])) |
84 | | - if result.get("stderr"): |
85 | | - content.append(TextContent(type="text", text=result["stderr"])) |
| 82 | + try: |
| 83 | + # Handle execution with timeout |
| 84 | + try: |
| 85 | + result = await asyncio.wait_for( |
| 86 | + self.executor.execute( |
| 87 | + command, directory, stdin, None |
| 88 | + ), # Pass None for timeout |
| 89 | + timeout=timeout, |
| 90 | + ) |
| 91 | + except asyncio.TimeoutError as e: |
| 92 | + raise ValueError("Command execution timed out") from e |
| 93 | + |
| 94 | + if result.get("error"): |
| 95 | + raise ValueError(result["error"]) |
| 96 | + |
| 97 | + if result.get("stdout"): |
| 98 | + content.append(TextContent(type="text", text=result["stdout"])) |
| 99 | + if result.get("stderr"): |
| 100 | + content.append(TextContent(type="text", text=result["stderr"])) |
| 101 | + |
| 102 | + except asyncio.TimeoutError as e: |
| 103 | + raise ValueError(f"Command timed out after {timeout} seconds") from e |
86 | 104 |
|
87 | 105 | return content |
88 | 106 |
|
@@ -111,7 +129,6 @@ async def call_tool(name: str, arguments: Any) -> Sequence[TextContent]: |
111 | 129 |
|
112 | 130 | except Exception as e: |
113 | 131 | logger.error(traceback.format_exc()) |
114 | | - logger.error(f"Error during call_tool: {str(e)}") |
115 | 132 | raise RuntimeError(f"Error executing command: {str(e)}") from e |
116 | 133 |
|
117 | 134 |
|
|
0 commit comments