Skip to content

Commit b9f8bc1

Browse files
add yaml rule mcp
1 parent a9cf05c commit b9f8bc1

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

main.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,17 @@ def find_code(
1313
pattern: str = Field(description="The ast-grep pattern to search for"),
1414
language: str = Field(description="The language of the query", default=""),
1515
) -> List[dict[str, Any]]:
16-
"""Find code in a project folder that matches the given ast-grep pattern.
17-
"""
16+
"""Find code in a project folder that matches the given ast-grep pattern"""
1817
return run_ast_grep_command(pattern, project_folder, language)
1918

19+
@mcp.tool()
20+
def find_code_by_rule(
21+
project_folder: str = Field(description="The path to the project folder"),
22+
yaml: str = Field(description="The ast-grep pattern to search for"),
23+
) -> List[dict[str, Any]]:
24+
"""Find code using ast-grep's YAML rule in a project folder"""
25+
return run_ast_grep_yaml(yaml, project_folder)
26+
2027
def run_ast_grep_command(pattern: str, project_folder: str, language: Optional[str]) -> List[dict[str, Any]]:
2128
try:
2229
args = ["ast-grep", "--pattern", pattern, "--json", project_folder]
@@ -38,5 +45,24 @@ def run_ast_grep_command(pattern: str, project_folder: str, language: Optional[s
3845
print("Command not found")
3946
return []
4047

48+
def run_ast_grep_yaml(yaml: str, project_folder: str) -> List[dict[str, Any]]:
49+
try:
50+
args = ["ast-grep", "scan","--inline-rules", yaml, "--json", project_folder]
51+
# Run command and capture output
52+
result = subprocess.run(
53+
args,
54+
capture_output=True,
55+
text=True,
56+
check=True # Raises CalledProcessError if return code is non-zero
57+
)
58+
return json.loads(result.stdout)
59+
except subprocess.CalledProcessError as e:
60+
print(f"Command failed with return code {e.returncode}")
61+
print("Error output:", e.stderr)
62+
return e.stderr
63+
except FileNotFoundError:
64+
print("Command not found")
65+
return []
66+
4167
if __name__ == "__main__":
4268
mcp.run(transport = "stdio")

0 commit comments

Comments
 (0)